[pipeline/tuning] improve dispatch performance in both time and space cost (#1544)

Kirigaya Kazuto 2022-09-07 19:01:06 +08:00 committed by GitHub
parent 4f59693207
commit 6159d45417
3 changed files with 194 additions and 155 deletions
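The substance of the commit is a change to how the pipeline workers dispatch arguments: producer and consumer futures are no longer waited on in the dispatching path, but are stored directly in each WorkItem and resolved lazily by the worker thread through the new get_real_args helper, so dispatch neither blocks nor keeps an extra copy of the tensors alive. A minimal sketch of that pattern, assuming only torch (dispatch_eager and dispatch_lazy are illustrative names, not functions from this repository, and get_real_args here is a simplified version of the helper added below):

import torch
from torch.futures import Future


def get_real_args(args):
    # resolve futures only when the consumer actually needs the values
    real_args = []
    for arg in args:
        value = arg.wait() if isinstance(arg, Future) else arg
        if isinstance(value, list):
            real_args.extend(value)
        else:
            real_args.append(value)
    return real_args


def dispatch_eager(producer_futures):
    # old behaviour: block on every future before building the work item
    args = []
    for fut in producer_futures:
        args.extend(fut.wait())
    return args


def dispatch_lazy(producer_futures):
    # new behaviour: hand the futures to the work item; the consumer waits later
    return producer_futures


if __name__ == '__main__':
    fut = Future()
    fut.set_result([torch.ones(2, 2)])
    work_item_args = dispatch_lazy([fut])
    print(get_real_args(work_item_args))    # resolved only when consumed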

View File

@@ -1,7 +1,8 @@
 import threading
 from enum import Enum
-from typing import List, Any, Tuple, Dict
+from typing import List, Any, Tuple, Dict, Callable
 from abc import ABC
+import sys

 import torch
 from torch import nn
@@ -11,6 +12,7 @@ from torch._C._distributed_rpc import PyRRef
 from torch import autograd
 from torch import optim
 from tqdm import tqdm
+from time import time

 from colorama import Back, Style
@@ -30,6 +32,10 @@ def color_debug(text, prefix=' ', color='blue'):
 def tensor_shape_list(tensors):
+    if tensors is None:
+        return None
+    if isinstance(tensors, (int, float)):
+        return tensors
     if isinstance(tensors, torch.Tensor):
         return tensors.shape
     shapes = []
@@ -41,6 +47,25 @@ def tensor_shape_list(tensors):
     return shapes


+def get_real_args(args):
+    if isinstance(args, torch.Tensor):
+        return args
+    elif isinstance(args, list):
+        real_args = []
+        for arg in args:
+            if isinstance(arg, Future):
+                value = arg.wait()
+            else:
+                value = arg
+            if isinstance(value, list):
+                real_args.extend(value)
+            else:
+                real_args.append(value)
+        return real_args
+    else:
+        raise TypeError(f"Expect receive tensor or list, but receive {type(args)}")
+
+
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
@@ -112,18 +137,6 @@ class BackwardCache:
             setattr(self, arg_name, locals()[arg_name])


-class RemoteExecutor:
-
-    def __init__(self) -> None:
-        pass
-
-
-class RemoteOptimizer:
-
-    def __init__(self) -> None:
-        pass
-
-
 class Worker:

     def __init__(self,
@@ -133,6 +146,7 @@ class Worker:
                  num_microbatches: int,
                  use_1F1B: bool,
                  device: str,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
         self.pp_rank = pp_rank
@@ -140,7 +154,8 @@ class Worker:
         self.num_microbatches = num_microbatches
         self.checkpoint = checkpoint
         self.device = device
-        self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)
+        self.use_1F1B = use_1F1B
+        self._initialize_outstanding_range()

         # variable and const for context managment
         self.outstanding = 0
@@ -157,30 +172,39 @@ class Worker:
         # module partitions
         self.module_partition = module_partition.to(device)

+        if criterion:
+            assert callable(criterion)
+        self.criterion = criterion
+
         # container to maintain loop
         self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
+        self.microbatch_id_to_labels: Dict[int, Any] = dict()
         self.work_list: Dict[UniqueKey, WorkItem] = dict()
         self.output_list: Dict[UniqueKey, WorkItem] = dict()

         # lock for the list
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
+        self.label_lock = threading.Condition(threading.Lock())
+
+        self.step_lock = threading.Lock()
+        self.step_lock.acquire()

+        # main loop
         self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
         self.main_loop_thread.start()

     def _get_future_by_device(self):
         return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])

-    def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
+    def _initialize_outstanding_range(self):
         outstanding_range = None
-        if use_1F1B:
-            if pp_rank == actual_stage_num - 1:
+        if self.use_1F1B:
+            if self.pp_rank == self.actual_stage_num - 1:
                 outstanding_range = (0, 1)
             else:
-                outstanding_range = (actual_stage_num, actual_stage_num)
-        return outstanding_range
+                outstanding_range = (self.actual_stage_num, self.actual_stage_num)
+        self.outstanding_range = outstanding_range

     def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
         assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
@@ -189,20 +213,16 @@ class Worker:
     def get_output_by_key(self, key: UniqueKey) -> Any:
         with self.output_list_condition_lock:
-            while key not in self.output_list:
-                self.output_list_condition_lock.wait()
+            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]

         output = output_work_item.output.wait()
         # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')

         output_work_item.refcount += 1
         # all consumers have been satisfied, the work_item can be released
         with self.output_list_condition_lock:
-            if output_work_item.refcount == len(self.consumer_stage_ids):
+            if output_work_item.refcount >= len(self.consumer_stage_ids):
                 self.output_list.pop(key)

         return output

     def get_parameters(self) -> List[torch.Tensor]:
@@ -211,34 +231,28 @@ class Worker:
     def get_parameter_gradients(self) -> List[torch.Tensor]:
         return [p.grad for p in self.module_partition.parameters()]

-    def reset_pp_context(self):
-        self.forward_times = 0
-        self.backward_times = 0
-        self.outstanding = 0
-        self.microbatch_id_to_backward_cache.clear()
-        self.output_list.clear()
-
     # just for first pp_rank
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
-        assert self.consumer_stage_ids is not None
-        key = UniqueKey(microbatch_id, Phase.FORWARD)
-        output = self._get_future_by_device()
-        args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
-        work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
-                             forward_only)
         with self.work_list_condition_lock:
+            assert self.consumer_stage_ids is not None
+            consumer_num = len(self.consumer_stage_ids)
+            key = UniqueKey(microbatch_id, Phase.FORWARD)
+            output = self._get_future_by_device()
+            args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
+            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
+                                 self.num_microbatches, forward_only)
             self.work_list[key] = work_item
             color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

+    # just for last pp_rank
+    def set_labels(self, microbatch_id: int, microlabels: Any):
+        self.microbatch_id_to_labels[microbatch_id] = microlabels
+
     # just for last pp_rank
     def _begin_backward(self, microbatch_id: int):
         with self.work_list_condition_lock:
             assert self.producer_stage_ids is not None
-            producer_num = len(self.producer_stage_ids)
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             output = self._get_future_by_device()
             grad_wrt_loss = torch.tensor(1, device=self.device)
@@ -272,15 +286,10 @@ class Worker:
         color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
                     'magenta')

-        args = []
-        for i in range(producer_num):
-            producer_args = subscribe_forward_futures[i].wait()
-            args.extend(producer_args)
-
-        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
-                                           self.num_microbatches, forward_only)
-
-        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
+                                           microbatch_id, None, self.num_microbatches, forward_only)
+
+        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')

         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
@@ -312,16 +321,11 @@ class Worker:
             consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
             subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)

-        args = []
-        for i in range(consumer_num):
-            consumer_args = subscribe_backward_futures[i].wait()
-            args.extend(consumer_args)
-
         # flatten args
-        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
-                                           self.num_microbatches, False)
-
-        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
+        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
+                                           microbatch_id, None, self.num_microbatches, False)
+
+        # color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')

         # add work_item to work_list
         with self.work_list_condition_lock:
@@ -351,63 +355,50 @@ class Worker:
                 self.consumer_stage_ids.append(next_rank)

     def _get_work_item_key(self) -> UniqueKey:
-        with self.work_list_condition_lock:
-            while len(self.work_list) == 0:
-                self.work_list_condition_lock.wait()
-
-        # each stage must do Key(microbatch_id=0, phase=FORWARD) first
-        # before doing the operation, reset the context first
-        if self.reset_key in self.work_list:
-            self.reset_pp_context()
-
         # execute backward first (if backward phase in work_list)
         pp_rank = self.pp_rank
         actual_stage_num = self.actual_stage_num
         num_microbatches = self.num_microbatches
         is_last_stage = pp_rank == actual_stage_num - 1

-        select_work_list_key: UniqueKey = None
-
         if self.outstanding_range:
             if self.outstanding <= self.outstanding_range[0]:
                 target_phase = Phase.FORWARD
                 target_microbatch_id = self.forward_times
             elif self.outstanding >= self.outstanding_range[1]:
                 target_phase = Phase.BACKWARD
                 target_microbatch_id = self.backward_times
             else:
                 raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

             target_key = UniqueKey(target_microbatch_id, target_phase)
-            if target_key in self.work_list:
-                select_work_list_key = target_key

             # change outstanding_range at:
             # 1. forward times reach actual_stage_num, this is the end of continuous forward
             # 2. forward times reach num_microbatches, this is the end of 1F1B mode
             if not is_last_stage and \
-                    select_work_list_key is not None and \
-                    select_work_list_key.phase == Phase.FORWARD:
-                if select_work_list_key.microbatch_id == actual_stage_num - 1:
+                    target_key.phase == Phase.FORWARD:
+                if target_key.microbatch_id == actual_stage_num - 1:
                     outstanding_min = actual_stage_num - pp_rank - 1
                     outstanding_max = actual_stage_num - pp_rank
                     self.outstanding_range = (outstanding_min, outstanding_max)
-                elif select_work_list_key.microbatch_id == num_microbatches - 1:
+                elif target_key.microbatch_id == num_microbatches - 1:
                     self.outstanding_range = (0, 0)
         else:
             if self.forward_times < num_microbatches:
                 target_phase = Phase.FORWARD
                 target_microbatch_id = self.forward_times
             else:
                 target_phase = Phase.BACKWARD
                 target_microbatch_id = self.backward_times

             target_key = UniqueKey(target_microbatch_id, target_phase)
-            if target_key in self.work_list:
-                select_work_list_key = target_key

-        return select_work_list_key
+        with self.work_list_condition_lock:
+            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
+
+        return target_key

     def _consume_work_item_by_phase(self, work_item: WorkItem):
         phase = work_item.phase
@@ -421,7 +412,7 @@ class Worker:
         is_first_stage = (self.pp_rank == 0)
         is_last_stage = (self.pp_rank == self.actual_stage_num - 1)

-        # if self.pp_rank == 3:
+        # if self.pp_rank == 0:
         #     print(
         #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
         #     )
@@ -436,12 +427,7 @@ class Worker:
             if not forward_only:
                 self.outstanding += 1

-            # TODO : more elegant ?
-            for i in range(len(args)):
-                arg_obj = args[i]
-                if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
-                    args[i] = arg_obj.requires_grad_()
+            args = get_real_args(args)

             # last stage doesn't need to do checkpoint, for it will do backward instantly
             if forward_only:
@@ -458,7 +444,14 @@ class Worker:
                     use_checkpoint = True
                 else:
                     consume_result = self.module_partition(*args, **kwargs)
-                    stage_outputs = consume_result
+
+                    if is_last_stage and self.criterion:
+                        labels = self.microbatch_id_to_labels.pop(microbatch_id)
+                        loss: torch.Tensor = self.criterion(consume_result, labels)
+                        consume_result = loss.item()
+                    else:
+                        loss = consume_result
+
+                    stage_outputs = loss
                     stage_inputs = args
                     use_checkpoint = False
@@ -467,7 +460,8 @@ class Worker:
                                                                              stage_outputs,
                                                                              checkpoint=use_checkpoint)

-            consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result
+            consume_result = [consume_result] if isinstance(consume_result,
+                                                            (torch.Tensor, int, float)) else consume_result

             # if not forward_only, do the backward
             if not forward_only:
@@ -488,20 +482,21 @@ class Worker:
             stage_outputs = backward_cache.stage_outputs
             stage_inputs = backward_cache.stage_inputs
-            grad_tensors = args
-
             use_checkpoint = backward_cache.checkpoint

             if use_checkpoint:
                 stage_outputs = [self.module_partition(*stage_inputs)]

+            # overlap recompute and future.wait
+            grad_tensors = get_real_args(args)
+
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)

             # collect grad of input tensor
             consume_result = []
-            for input_node in stage_inputs:
-                if isinstance(input_node, torch.Tensor):
-                    consume_result.append(input_node.grad)
+            if not is_first_stage:
+                for input_node in stage_inputs:
+                    if isinstance(input_node, torch.Tensor):
+                        consume_result.append(input_node.grad)

         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -509,7 +504,20 @@ class Worker:
         return consume_result

     def _get_store_len(self):
-        return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'
+        return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
+
+    def _get_parameter_grad_sum(self):
+        grad_sum = 0
+        for p in self.module_partition.parameters():
+            if p.grad is not None:
+                grad_sum += p.grad.sum()
+        return grad_sum
+
+    def _is_first_step(self, work_item) -> bool:
+        return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
+
+    def _is_last_step(self, work_item) -> bool:
+        return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1

     # do the main loop to consume ready_list
     def _work_loop(self):
@@ -519,8 +527,6 @@ class Worker:
         # main loop
         while True:
             work_item_key = self._get_work_item_key()
-            if work_item_key is None:
-                continue

             # move current work item to output_list to activate subscribe in advance
             with self.work_list_condition_lock:
@@ -540,20 +546,32 @@ class Worker:
             color_debug(
                 f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
                 'work loop', 'green')

             work_item.output.set_result(consume_result)

+            # if is last step in one batch reset context and do step
+            if self._is_last_step(work_item):
+                if hasattr(self, 'optimizer'):
+                    self.step()
+                self.forward_times = 0
+                self.backward_times = 0
+                self.outstanding = 0
+                self._initialize_outstanding_range()
+
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)

-    def step(self):
-        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
+    def wait_for_step(self):
+        self.step_lock.acquire()
+        self.work_list.clear()
+        self.output_list.clear()
+        self.microbatch_id_to_backward_cache.clear()
+
+    def step(self):
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.step()
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
         self.optimizer.zero_grad()
+        self.step_lock.release()


 class PipelineEngineBase(ABC, nn.Module):
@@ -564,10 +582,12 @@ class PipelineEngineBase(ABC, nn.Module):
                  device: str,
                  use_1F1B=False,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         super().__init__()
         self.module_partitions: List[nn.Module] = module_partitions
         self.chunk = chunk
+        self.criterion = criterion
         self.num_microbatches = num_microbatches
         self.device = device
         self.use_1F1B = use_1F1B
@@ -577,6 +597,8 @@ class PipelineEngineBase(ABC, nn.Module):
         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()

+        self.step_futs: List[Future] = []
+
         self._check_argument()
         self._create_pp_rank_to_rpc_worker_id()
         self._init_worker()
@@ -613,6 +635,7 @@ class PipelineEngineBase(ABC, nn.Module):
         checkpoint = self.checkpoint
         num_microbatches = self.num_microbatches
         device = self.device
+        criterion = self.criterion

         for pp_rank in range(actual_stage_num):
             module_partition = self.module_partitions[pp_rank]
@@ -622,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
             self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
                                                               Worker,
                                                               args=(module_partition, pp_rank, actual_stage_num,
-                                                                    num_microbatches, use_1F1B, device, checkpoint))
+                                                                    num_microbatches, use_1F1B, device, criterion,
+                                                                    checkpoint))

         # let each worker know global worker rref (include itself)
         for pp_rank in range(actual_stage_num):
@@ -646,12 +670,15 @@ class PipelineEngineBase(ABC, nn.Module):
                 grads[stage_id].append(grad)
         return grads

-    def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
+    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
+        if labels is not None:
+            assert len(batch) == len(labels)
+
         num_microbatches = self.num_microbatches
         microbatch_size = len(batch) // num_microbatches
         actual_stage_num = self._get_actual_stage_num()

-        first_stage_worker = self.pp_rank_to_worker_rref[0]
+        first_worker_rref = self.pp_rank_to_worker_rref[0]
         last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]

         microbatch_iter = range(num_microbatches)
@@ -659,11 +686,7 @@ class PipelineEngineBase(ABC, nn.Module):
             microbatch_iter = tqdm(microbatch_iter)

         ret_future: List[Future] = [None] * num_microbatches
-        from time import sleep
-
         for microbatch_id in microbatch_iter:
-            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
-
             # control data input speed
             # to prevent exceed of wait limitations
             if microbatch_id >= actual_stage_num:
@@ -671,15 +694,27 @@ class PipelineEngineBase(ABC, nn.Module):
                     ret_future[microbatch_id - actual_stage_num].wait()
                 else:
                     key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
-                    first_stage_worker.rpc_sync().get_output_by_key(key)
+                    first_worker_rref.rpc_sync().get_output_by_key(key)

-            # run one microbatch
-            first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)
+            # set input
+            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
+            microbatch = microbatch.cuda()
+            first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+
+            # set labels
+            if not forward_only and labels is not None:
+                microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
+                microlabels = microlabels.cuda()
+                last_worker_rref.remote().set_labels(microbatch_id, microlabels)
+
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)

-        # wait forward
+        # wait for last backward in rank0
+        if not forward_only:
+            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
+            first_worker_rref.rpc_sync().get_output_by_key(key)
+
+        # collect forward result
         # TODO : all the node to output
         forward_result = None
@@ -691,28 +726,30 @@ class PipelineEngineBase(ABC, nn.Module):
             for i in range(len(forward_result)):
                 forward_result[i].append(ret[i])

-        # wait for last backward in rank0
-        if not forward_only:
-            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-            first_stage_worker.rpc_sync().get_output_by_key(key)
+        if hasattr(self, 'optimizer_class'):
+            # wait for all step
+            # TODO : more elegant ?
+            for pp_rank in self.pp_rank_to_worker_rref:
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                worker_rref.rpc_sync().wait_for_step()

         return forward_result

     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         actual_stage_num = self._get_actual_stage_num()
+        self.optimizer_class = optimizer_class
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+            worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)

     def step(self):
-        step_futs: List[Future] = []
         actual_stage_num = self._get_actual_stage_num()
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().step()
-            step_futs.append(fut)
+            self.step_futs.append(fut)

-        # wait for all optimizers
-        for fut in step_futs:
+        for fut in self.step_futs:
             fut.wait()
@@ -724,9 +761,10 @@ class FillDrainPipelineEngine(PipelineEngineBase):
                  num_microbatches: int,
                  device: str,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         use_1F1B = False
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)


 class OneFOneBPipelineEngine(PipelineEngineBase):
@@ -737,6 +775,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
                  num_microbatches: int,
                  device: str,
                  chunk: int = 1,
+                 criterion: Callable = None,
                  checkpoint: bool = False) -> None:
         use_1F1B = True
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
+        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
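With the criterion and labels plumbing above, the engine-facing API after this commit can be used roughly as follows. This is a sketch, not code from the repository: the import path, the toy partitioning, and the keyword-argument style are assumptions, and it presumes torch.distributed.rpc has been initialized with one process per stage (as in the repository's RPC tests) and that CUDA is available, since microbatches are moved with .cuda().

import torch
import torch.nn as nn

from colossalai.pipeline.rpc import OneFOneBPipelineEngine    # import path is an assumption

partitions = [nn.Linear(16, 16), nn.Linear(16, 4)]    # toy two-stage split

engine = OneFOneBPipelineEngine(module_partitions=partitions,
                                stage_num=2,
                                num_microbatches=4,
                                device='cuda',
                                chunk=1,
                                criterion=nn.CrossEntropyLoss(),    # new: loss is computed on the last stage
                                checkpoint=False)

engine.initialize_optimizer(torch.optim.SGD, lr=1e-3)

batch = torch.randn(32, 16)
labels = torch.randint(0, 4, (32,))

# new signature: labels are sliced per microbatch and sent to the last stage;
# because an optimizer was initialized, forward_backward also blocks on each
# worker's wait_for_step(), so no explicit engine.step() call is needed.
losses = engine.forward_backward(batch, labels=labels)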

View File

@@ -45,6 +45,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--epoch', type=int, default=1)
     parser.add_argument('--world_size', type=int, default=2)
+    parser.add_argument('--batch_size', type=int, default=16)
     parser.add_argument('--num_microbatches', type=int, default=2)
     parser.add_argument('--chunk', type=int, default=1)
     parser.add_argument('--use_checkpoint', action='store_true')

View File

@@ -43,7 +43,6 @@ def run_master(args):
     engine.initialize_optimizer(optimizer_class, lr=lr)
     _ = engine.forward_backward(input_sample)
-    engine.step()

     cuda_rpc_result = []
     single_result = []
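The dropped engine.step() call above is replaced by the step_lock handshake introduced in the Worker class: each worker holds its step_lock from construction, releases it at the end of step() (which the work loop now triggers after the last backward microbatch), and the engine's forward_backward blocks on wait_for_step(). A toy, self-contained sketch of that handshake, with the Worker and engine classes stripped away (only the lock discipline is taken from the diff):

import threading
import time

step_lock = threading.Lock()
step_lock.acquire()              # held from construction, as in Worker.__init__


def worker_loop():
    time.sleep(0.1)              # stand-in for consuming the last backward work item
    # ... optimizer.step(); optimizer.zero_grad() ...
    step_lock.release()          # end of Worker.step(): signal that the step is done


def wait_for_step():
    step_lock.acquire()          # blocks until the worker thread released the lock
    # the real Worker also clears work_list / output_list / backward_cache here


threading.Thread(target=worker_loop, daemon=True).start()
wait_for_step()
print('step finished on this rank')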