[pipeline/rpc] update outstanding mechanism | optimize dispatching strategy (#1497)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test

* [engin/schedule] use p2p_v2 to recontruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy
This commit is contained in:
Kirigaya Kazuto
2022-08-26 14:04:23 +08:00
committed by GitHub
parent 0ed2f46131
commit 5a6fd71f90
5 changed files with 174 additions and 135 deletions

View File

@@ -68,7 +68,7 @@ class UniqueKey:
class WorkItem:
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
'num_microbatches')
'num_microbatches', 'forward_only')
stage_id: int
phase: Phase
@@ -81,6 +81,7 @@ class WorkItem:
batch_id: int
num_microbatches: int
forward_only: bool
def __init__(self,
stage_id,
@@ -91,6 +92,7 @@ class WorkItem:
microbatch_id,
batch_id,
num_microbatches,
forward_only,
refcount=0) -> None:
for attr_name in self.__slots__:
setattr(self, attr_name, locals()[attr_name])
@@ -129,36 +131,39 @@ class Worker:
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
max_outstanding: int,
use_1F1B: bool,
device: str,
checkpoint: bool = False) -> None:
super().__init__()
self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches
self.max_outstanding = max_outstanding
self.outstanding = 0
self.checkpoint = checkpoint
self.device = device
self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)
self.future_devices = None if device is None or device == 'cpu' else [device]
# variable and const for context managment
self.outstanding = 0
self.forward_times = 0
self.backward_times = 0
self.reset_key = UniqueKey(0, Phase.FORWARD)
# rref of other workers
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
# topology info
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
# module
# module partitions
self.module_partition = module_partition.to(device)
self.debug_list = [None] * num_microbatches
# container to maintain loop
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
self.work_list: Dict[UniqueKey, WorkItem] = dict()
self.output_list: Dict[UniqueKey, WorkItem] = dict()
# Why must a Lock instead of RLock ?
# Because RLock cannot be pickled
# lock for the list
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
@@ -168,6 +173,15 @@ class Worker:
def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
outstanding_range = None
if use_1F1B:
if pp_rank == actual_stage_num - 1:
outstanding_range = (0, 1)
else:
outstanding_range = (actual_stage_num, actual_stage_num)
return outstanding_range
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
@@ -197,8 +211,15 @@ class Worker:
def get_parameter_gradients(self) -> List[torch.Tensor]:
return [p.grad for p in self.module_partition.parameters()]
def reset_pp_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self.microbatch_id_to_backward_cache.clear()
self.output_list.clear()
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
with self.work_list_condition_lock:
assert self.consumer_stage_ids is not None
consumer_num = len(self.consumer_stage_ids)
@@ -207,11 +228,10 @@ class Worker:
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
self.num_microbatches, forward_only)
self.work_list[key] = work_item
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last pp_rank
@@ -224,24 +244,22 @@ class Worker:
grad_wrt_loss = torch.tensor(1, device=self.device)
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
self.num_microbatches, False)
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
def subscribe_producer(self, microbatch_id: int):
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
consumer_num = len(self.consumer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
@@ -259,9 +277,8 @@ class Worker:
producer_args = subscribe_forward_futures[i].wait()
args.extend(producer_args)
# TODO : not only args
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
self.num_microbatches, forward_only)
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
@@ -279,13 +296,10 @@ class Worker:
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
consumer_num = len(self.consumer_stage_ids)
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
# TODO : is this right?
stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
@@ -305,7 +319,7 @@ class Worker:
# flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
self.num_microbatches, False)
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
@@ -341,32 +355,57 @@ class Worker:
while len(self.work_list) == 0:
self.work_list_condition_lock.wait()
# execute backward first (if backward phase in work_list)
select_work_list_key = None
for key in self.work_list:
work_item = self.work_list[key]
if work_item.phase == Phase.FORWARD and \
self.max_outstanding is not None and \
self.outstanding >= self.max_outstanding:
continue
else:
if select_work_list_key is not None and \
select_work_list_key.phase == Phase.FORWARD and \
key.phase == Phase.BACKWARD:
continue
# each stage must do Key(microbatch_id=0, phase=FORWARD) first
# before doing the operation, reset the context first
if self.reset_key in self.work_list:
self.reset_pp_context()
if select_work_list_key is None:
select_work_list_key = key
else:
phase_pair = (select_work_list_key.phase, key.phase)
# choose forward first
if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
select_work_list_key = key
elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
continue
# choose work_item which has a smaller microbactch_id first
elif key.microbatch_id < select_work_list_key.microbatch_id:
select_work_list_key = key
# execute backward first (if backward phase in work_list)
pp_rank = self.pp_rank
actual_stage_num = self.actual_stage_num
num_microbatches = self.num_microbatches
is_last_stage = pp_rank == actual_stage_num - 1
select_work_list_key: UniqueKey = None
if self.outstanding_range:
if self.outstanding <= self.outstanding_range[0]:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
elif self.outstanding >= self.outstanding_range[1]:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
target_key = UniqueKey(target_microbatch_id, target_phase)
if target_key in self.work_list:
select_work_list_key = target_key
# change outstanding_range at:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
select_work_list_key is not None and \
select_work_list_key.phase == Phase.FORWARD:
if select_work_list_key.microbatch_id == actual_stage_num - 1:
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
elif select_work_list_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
else:
if self.forward_times < num_microbatches:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
target_key = UniqueKey(target_microbatch_id, target_phase)
if target_key in self.work_list:
select_work_list_key = target_key
return select_work_list_key
@@ -375,15 +414,28 @@ class Worker:
args = work_item.args
kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
consume_result = None
# if self.pp_rank == 0:
# print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
# TODO : use process manager to acquire rank info later
is_first_stage = (self.pp_rank == 0)
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')
# if self.pp_rank == 3:
# print(
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
# )
if phase == Phase.FORWARD:
self.outstanding += 1
# remind its consumer to get data before forward
if not is_last_stage:
for stage_id in self.consumer_stage_ids:
consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
self.forward_times += 1
if not forward_only:
self.outstanding += 1
# TODO : more elegant ?
for i in range(len(args)):
@@ -391,35 +443,46 @@ class Worker:
if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
args[i] = arg_obj.requires_grad_()
# TODO : use process manager to acquire rank info later
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# last stage doesn't need to do checkpoint, for it will do backward instantly
if self.checkpoint and not is_last_stage:
if forward_only:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = None
use_checkpoint = None
elif self.checkpoint and not is_last_stage:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
stage_outputs,
checkpoint=True)
use_checkpoint = True
else:
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = consume_result
stage_inputs = args
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
stage_outputs,
checkpoint=False)
checkpoint=use_checkpoint)
consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result
# if it is the last stage, trigger backward automatic
if is_last_stage:
self._begin_backward(microbatch_id)
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
self._begin_backward(microbatch_id)
elif phase == Phase.BACKWARD:
# remind its producer to get data before backward
if not is_first_stage:
for stage_id in self.producer_stage_ids:
producer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
producer_worker_rref.remote().subscribe_consumer(microbatch_id)
self.backward_times += 1
self.outstanding -= 1
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
@@ -445,6 +508,9 @@ class Worker:
return consume_result
def _get_store_len(self):
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'
# do the main loop to consume ready_list
def _work_loop(self):
# for init
@@ -461,7 +527,7 @@ class Worker:
work_item = self.work_list.pop(work_item_key)
color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green')
with self.output_list_condition_lock:
@@ -472,7 +538,7 @@ class Worker:
consume_result = self._consume_work_item_by_phase(work_item)
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
'work loop', 'green')
work_item.output.set_result(consume_result)
@@ -489,9 +555,6 @@ class Worker:
self.optimizer.zero_grad()
# TODO
# 1. chunk
# 2. checkpoint
class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
@@ -499,19 +562,18 @@ class PipelineEngineBase(ABC, nn.Module):
stage_num,
num_microbatches,
device: str,
max_outstanding=None,
use_1F1B=False,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
super().__init__()
self.module_partitions: List[nn.Module] = module_partitions
self.chunk = chunk
self.num_microbatches = num_microbatches
self.device = device
self.max_outstanding = max_outstanding
self.use_1F1B = use_1F1B
self.stage_num = stage_num
self.checkpoint = checkpoint
self.use_interleave = use_interleave
self.use_interleave = chunk > 1
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@@ -547,7 +609,7 @@ class PipelineEngineBase(ABC, nn.Module):
def _init_worker(self):
actual_stage_num = self._get_actual_stage_num()
max_outstanding = self.max_outstanding
use_1F1B = self.use_1F1B
checkpoint = self.checkpoint
num_microbatches = self.num_microbatches
device = self.device
@@ -560,8 +622,7 @@ class PipelineEngineBase(ABC, nn.Module):
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
Worker,
args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, max_outstanding, device,
checkpoint))
num_microbatches, use_1F1B, device, checkpoint))
# let each worker know global worker rref (include itself)
for pp_rank in range(actual_stage_num):
@@ -585,46 +646,55 @@ class PipelineEngineBase(ABC, nn.Module):
grads[stage_id].append(grad)
return grads
def forward_backward(self, batch: torch.Tensor):
first_stage_worker = self.pp_rank_to_worker_rref[0]
microbatch_size = len(batch) // self.num_microbatches
def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
num_microbatches = self.num_microbatches
microbatch_size = len(batch) // num_microbatches
actual_stage_num = self._get_actual_stage_num()
microbatch_iter = range(self.num_microbatches)
first_stage_worker = self.pp_rank_to_worker_rref[0]
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
microbatch_iter = range(num_microbatches)
if use_progress:
microbatch_iter = tqdm(microbatch_iter)
ret_future: List[Future] = [None] * num_microbatches
from time import sleep
for microbatch_id in microbatch_iter:
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
# forward subscribe asynchronously
for pp_rank in range(1, actual_stage_num, 1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_producer(microbatch_id)
# backward subscribe asynchronously
for pp_rank in range(actual_stage_num - 2, -1, -1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_consumer(microbatch_id)
# control data input speed
# to prevent exceed of wait limitations
if microbatch_id >= actual_stage_num:
if forward_only or not self.use_1F1B:
ret_future[microbatch_id - actual_stage_num].wait()
else:
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
# run one microbatch
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
# wait forward
# TODO : all the node to output
forward_result = None
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
for microbatch_id in range(self.num_microbatches):
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret = last_worker_rref.rpc_sync().get_output_by_key(key)
ret = ret_future[microbatch_id].wait()
if forward_result is None:
forward_result = [[]] * len(ret)
for i in range(len(forward_result)):
forward_result[i].append(ret[i])
# wait for last backward in rank0
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
if not forward_only:
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):
@@ -654,11 +724,9 @@ class FillDrainPipelineEngine(PipelineEngineBase):
num_microbatches: int,
device: str,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
max_outstanding = None
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)
use_1F1B = False
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
class OneFOneBPipelineEngine(PipelineEngineBase):
@@ -668,11 +736,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
stage_num: int,
num_microbatches: int,
device: str,
max_outstanding=None,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
if max_outstanding is None:
max_outstanding = len(module_partitions)
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)
use_1F1B = True
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)