[pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B (#1483)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test

* [engin/schedule] use p2p_v2 to recontruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B
This commit is contained in:
Kirigaya Kazuto
2022-08-24 11:19:46 +08:00
committed by GitHub
parent d6e3dca436
commit a6c8749198
3 changed files with 366 additions and 139 deletions

View File

@@ -1,7 +1,7 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict
from abc import ABC, abstractmethod
from abc import ABC
import torch
from torch import nn
@@ -18,9 +18,8 @@ use_color_debug = False
use_progress = False
# TODO:
# 1. design a unique_key without node.name (Maybe I can use combination of microbatch_id and stage_id)
# 2. use waiting list to contain the uncomplete WorkItem
# 3. think about the representation of the order of args and kwargs
# 1. replace world_size with other parameters
# 2. adjust to args and kwargs
def color_debug(text, prefix=' ', color='blue'):
@@ -126,33 +125,32 @@ class RemoteOptimizer:
class Worker:
def __init__(self,
cur_rank_module: nn.Module,
rank: int,
world_size: int,
module_partition: nn.Module,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
max_outstanding: int,
device: str,
checkpoint: bool = False) -> None:
super().__init__()
self.rank = rank
self.world_size = world_size
self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches
self.max_outstanding = max_outstanding
self.outstanding = 0
self.checkpoint = checkpoint
if device == 'cuda':
device = f'cuda:{rank}'
self.device = device
self.future_devices = None if device is None or device == 'cpu' else [device]
self.stage_to_worker_rref: Dict[int, PyRRef] = None
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
# module
self.cur_rank_module = cur_rank_module.to(device)
self.module_partition = module_partition.to(device)
self.debug_list = [None] * num_microbatches
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
@@ -164,16 +162,16 @@ class Worker:
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True)
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
self.main_loop_thread.start()
def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs"
assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
self.stage_to_worker_rref = stage_to_worker_rref
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
def get_output_by_key(self, key: UniqueKey) -> Any:
with self.output_list_condition_lock:
@@ -183,7 +181,7 @@ class Worker:
output_work_item = self.output_list[key]
output = output_work_item.output.wait()
# color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red')
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released
@@ -193,8 +191,13 @@ class Worker:
return output
# just for first rank
# TODO : input is args kwargs
def get_parameters(self) -> List[torch.Tensor]:
return [p for p in self.module_partition.parameters()]
def get_parameter_gradients(self) -> List[torch.Tensor]:
return [p.grad for p in self.module_partition.parameters()]
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
with self.work_list_condition_lock:
assert self.consumer_stage_ids is not None
@@ -203,16 +206,15 @@ class Worker:
output = self._get_future_by_device()
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
consumer_num)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
self.work_list[key] = work_item
color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last rank
# TODO : write a function to add gradient to work_list and see if there is contradictory
# just for last pp_rank
def _begin_backward(self, microbatch_id: int):
with self.work_list_condition_lock:
assert self.producer_stage_ids is not None
@@ -221,10 +223,10 @@ class Worker:
output = self._get_future_by_device()
grad_wrt_loss = torch.tensor(1, device=self.device)
work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@@ -238,7 +240,7 @@ class Worker:
consumer_num = len(self.consumer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.rank
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
@@ -246,10 +248,10 @@ class Worker:
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.stage_to_worker_rref[producer_stage_id]
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
'magenta')
args = []
@@ -261,14 +263,14 @@ class Worker:
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
color_debug(
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
@@ -282,18 +284,18 @@ class Worker:
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
# TODO : is this right?
stage_id = self.rank
stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch',
'magenta')
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id]
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
args = []
@@ -305,7 +307,7 @@ class Worker:
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
@@ -313,13 +315,12 @@ class Worker:
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
color_debug(
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# TODO : fit in any type of partition of network
def _get_producer_consumer(self) -> None:
rank = self.rank
rank = self.pp_rank
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
@@ -332,34 +333,41 @@ class Worker:
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.world_size - 1:
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
def _skip_forward(self, work_item_phase: Phase) -> bool:
if work_item_phase == Phase.FORWARD and \
self.max_outstanding is not None and \
self.outstanding >= self.max_outstanding:
return True
return False
def _get_work_item_key(self) -> UniqueKey:
with self.work_list_condition_lock:
while len(self.work_list) == 0:
self.work_list_condition_lock.wait()
# execute backward first (if backward phase in work_list)
select_work_list_key = None
for key in self.work_list:
work_item = self.work_list[key]
if work_item.phase == Phase.BACKWARD:
return key
if self._skip_forward(work_item.phase):
if work_item.phase == Phase.FORWARD and \
self.max_outstanding is not None and \
self.outstanding >= self.max_outstanding:
continue
else:
select_work_list_key = key
if select_work_list_key is not None and \
select_work_list_key.phase == Phase.FORWARD and \
key.phase == Phase.BACKWARD:
continue
if select_work_list_key is None:
select_work_list_key = key
else:
phase_pair = (select_work_list_key.phase, key.phase)
# choose forward first
if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
select_work_list_key = key
elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
continue
# choose work_item which has a smaller microbactch_id first
elif key.microbatch_id < select_work_list_key.microbatch_id:
select_work_list_key = key
return select_work_list_key
def _consume_work_item_by_phase(self, work_item: WorkItem):
@@ -369,7 +377,10 @@ class Worker:
microbatch_id = work_item.microbatch_id
consume_result = None
# color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue')
# if self.pp_rank == 0:
# print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
# color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')
if phase == Phase.FORWARD:
self.outstanding += 1
@@ -381,19 +392,20 @@ class Worker:
args[i] = arg_obj.requires_grad_()
# TODO : use process manager to acquire rank info later
is_last_stage = len(self.consumer_stage_ids) == 0
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# last stage doesn't need to do checkpoint, for it will do backward instantly
if self.checkpoint and not is_last_stage:
with torch.no_grad():
consume_result = self.cur_rank_module(*args, **kwargs)
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
stage_outputs,
checkpoint=True)
else:
# TODO : replace with *args, **kwargs and ensure the consume_result is a tuple
consume_result = self.cur_rank_module(*args, **kwargs)
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = consume_result
stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
@@ -415,15 +427,13 @@ class Worker:
stage_inputs = backward_cache.stage_inputs
grad_tensors = args
# color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow')
use_checkpoint = backward_cache.checkpoint
if self.checkpoint:
stage_outputs = [self.cur_rank_module(*stage_inputs)]
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_inputs)]
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow')
# collect grad of input tensor
consume_result = []
for input_node in stage_inputs:
@@ -453,7 +463,7 @@ class Worker:
work_item = self.work_list.pop(work_item_key)
color_debug(
f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
'work loop', 'green')
with self.output_list_condition_lock:
@@ -464,11 +474,8 @@ class Worker:
consume_result = self._consume_work_item_by_phase(work_item)
color_debug(
f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
'work loop', 'green')
# if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD:
# from time import sleep
# sleep(5)
work_item.output.set_result(consume_result)
@@ -479,11 +486,11 @@ class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
module_partitions,
chunk,
world_size,
stage_num,
num_microbatches,
device: str,
max_outstanding=None,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
super().__init__()
@@ -492,55 +499,86 @@ class PipelineEngineBase(ABC, nn.Module):
self.num_microbatches = num_microbatches
self.device = device
self.max_outstanding = max_outstanding
self.world_size = world_size
self.stage_num = stage_num
self.checkpoint = checkpoint
self.use_interleave = use_interleave
self.stage_to_worker_rref: Dict[int, PyRRef] = dict()
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._init_worker()
def _check_argument(self):
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
assert self.virtual_stage_num == len(
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
if self.use_interleave:
assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
def _get_actual_stage_num(self):
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
def _create_pp_rank_to_rpc_worker_id(self):
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
of partitions will be moved to device 0 and the others to device 1
"""
stage_num = self.stage_num
actual_stage_num = self._get_actual_stage_num()
self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num
for pp_rank in range(actual_stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
def _init_worker(self):
world_size = self.world_size
actual_stage_num = self._get_actual_stage_num()
max_outstanding = self.max_outstanding
checkpoint = self.checkpoint
num_microbatches = self.num_microbatches
device = self.device
# TODO : world size is correct ?
for rank in range(world_size):
cur_rank_module = self.module_partitions[rank]
self.stage_to_worker_rref[rank] = rpc.remote(rank,
Worker,
args=(cur_rank_module, rank, world_size, num_microbatches,
max_outstanding, device, checkpoint))
for pp_rank in range(actual_stage_num):
module_partition = self.module_partitions[pp_rank]
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
Worker,
args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, max_outstanding, device,
checkpoint))
# let each worker know global worker rref (include itself)
for rank in range(world_size):
self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref)
for pp_rank in range(actual_stage_num):
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
@abstractmethod
def forward_backward(self):
pass
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
parameters = {}
for stage_id in self.pp_rank_to_worker_rref:
parameters[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for p in worker_rref.rpc_sync().get_parameters():
parameters[stage_id].append(p)
return parameters
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
grads = {}
for stage_id in self.pp_rank_to_worker_rref:
grads[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for grad in worker_rref.rpc_sync().get_parameter_gradients():
grads[stage_id].append(grad)
return grads
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions,
chunk,
world_size,
num_microbatches,
device: str,
max_outstanding=None,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
use_interleave, checkpoint)
# TODO : adjust to args and kwargs
def forward_backward(self, batch: torch.Tensor):
first_stage_worker = self.stage_to_worker_rref[0]
first_stage_worker = self.pp_rank_to_worker_rref[0]
microbatch_size = len(batch) // self.num_microbatches
actual_stage_num = self._get_actual_stage_num()
microbatch_iter = range(self.num_microbatches)
if use_progress:
@@ -550,31 +588,63 @@ class FillDrainPipelineEngine(PipelineEngineBase):
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
# forward subscribe asynchronously
for rank in range(1, self.world_size, 1):
worker_rref = self.stage_to_worker_rref[rank]
for pp_rank in range(1, actual_stage_num, 1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_producer(microbatch_id)
# backward subscribe asynchronously
for rank in range(self.world_size - 2, -1, -1):
worker_rref = self.stage_to_worker_rref[rank]
for pp_rank in range(actual_stage_num - 2, -1, -1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_consumer(microbatch_id)
# run one microbatch
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)
# wait forward
# TODO : all the node to output
forward_result = None
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
for microbatch_id in range(self.num_microbatches):
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret = last_worker_rref.rpc_sync().get_output_by_key(key)
if forward_result is None:
forward_result = [[]] * len(ret)
for i in range(len(forward_result)):
forward_result[i].append(ret[i])
class OneFOneBPipelineEngine(FillDrainPipelineEngine):
# wait for last backward in rank0
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
return forward_result
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions,
chunk,
world_size,
num_microbatches,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
max_outstanding = None
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
max_outstanding=None,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
if max_outstanding is None:
max_outstanding = world_size
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
use_interleave, checkpoint)
max_outstanding = len(module_partitions)
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)