mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 07:22:12 +00:00
[pipeline/tuning] improve dispatch performance both time and space cost (#1544)
This commit is contained in:
parent
4f59693207
commit
6159d45417
@ -1,7 +1,8 @@
|
|||||||
import threading
|
import threading
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Any, Tuple, Dict
|
from typing import List, Any, Tuple, Dict, Callable
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -11,6 +12,7 @@ from torch._C._distributed_rpc import PyRRef
|
|||||||
from torch import autograd
|
from torch import autograd
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from time import time
|
||||||
|
|
||||||
from colorama import Back, Style
|
from colorama import Back, Style
|
||||||
|
|
||||||
@ -30,6 +32,10 @@ def color_debug(text, prefix=' ', color='blue'):
|
|||||||
|
|
||||||
|
|
||||||
def tensor_shape_list(tensors):
|
def tensor_shape_list(tensors):
|
||||||
|
if tensors is None:
|
||||||
|
return None
|
||||||
|
if isinstance(tensors, (int, float)):
|
||||||
|
return tensors
|
||||||
if isinstance(tensors, torch.Tensor):
|
if isinstance(tensors, torch.Tensor):
|
||||||
return tensors.shape
|
return tensors.shape
|
||||||
shapes = []
|
shapes = []
|
||||||
@ -41,6 +47,25 @@ def tensor_shape_list(tensors):
|
|||||||
return shapes
|
return shapes
|
||||||
|
|
||||||
|
|
||||||
|
def get_real_args(args):
|
||||||
|
if isinstance(args, torch.Tensor):
|
||||||
|
return args
|
||||||
|
elif isinstance(args, list):
|
||||||
|
real_args = []
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, Future):
|
||||||
|
value = arg.wait()
|
||||||
|
else:
|
||||||
|
value = arg
|
||||||
|
if isinstance(value, list):
|
||||||
|
real_args.extend(value)
|
||||||
|
else:
|
||||||
|
real_args.append(value)
|
||||||
|
return real_args
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Expect receive tensor or list, but receive {type(args)}")
|
||||||
|
|
||||||
|
|
||||||
class Phase(Enum):
|
class Phase(Enum):
|
||||||
FORWARD = 0
|
FORWARD = 0
|
||||||
BACKWARD = 1
|
BACKWARD = 1
|
||||||
@ -112,18 +137,6 @@ class BackwardCache:
|
|||||||
setattr(self, arg_name, locals()[arg_name])
|
setattr(self, arg_name, locals()[arg_name])
|
||||||
|
|
||||||
|
|
||||||
class RemoteExecutor:
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteOptimizer:
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -133,6 +146,7 @@ class Worker:
|
|||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
use_1F1B: bool,
|
use_1F1B: bool,
|
||||||
device: str,
|
device: str,
|
||||||
|
criterion: Callable = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pp_rank = pp_rank
|
self.pp_rank = pp_rank
|
||||||
@ -140,7 +154,8 @@ class Worker:
|
|||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
self.device = device
|
self.device = device
|
||||||
self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)
|
self.use_1F1B = use_1F1B
|
||||||
|
self._initialize_outstanding_range()
|
||||||
|
|
||||||
# variable and const for context managment
|
# variable and const for context managment
|
||||||
self.outstanding = 0
|
self.outstanding = 0
|
||||||
@ -157,30 +172,39 @@ class Worker:
|
|||||||
|
|
||||||
# module partitions
|
# module partitions
|
||||||
self.module_partition = module_partition.to(device)
|
self.module_partition = module_partition.to(device)
|
||||||
|
if criterion:
|
||||||
|
assert callable(criterion)
|
||||||
|
self.criterion = criterion
|
||||||
|
|
||||||
# container to maintain loop
|
# container to maintain loop
|
||||||
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
|
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
|
||||||
|
self.microbatch_id_to_labels: Dict[int, Any] = dict()
|
||||||
self.work_list: Dict[UniqueKey, WorkItem] = dict()
|
self.work_list: Dict[UniqueKey, WorkItem] = dict()
|
||||||
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
||||||
|
|
||||||
# lock for the list
|
# lock for the list
|
||||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
|
self.label_lock = threading.Condition(threading.Lock())
|
||||||
|
|
||||||
|
self.step_lock = threading.Lock()
|
||||||
|
self.step_lock.acquire()
|
||||||
|
|
||||||
|
# main loop
|
||||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||||
self.main_loop_thread.start()
|
self.main_loop_thread.start()
|
||||||
|
|
||||||
def _get_future_by_device(self):
|
def _get_future_by_device(self):
|
||||||
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
|
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
|
||||||
|
|
||||||
def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
|
def _initialize_outstanding_range(self):
|
||||||
outstanding_range = None
|
outstanding_range = None
|
||||||
if use_1F1B:
|
if self.use_1F1B:
|
||||||
if pp_rank == actual_stage_num - 1:
|
if self.pp_rank == self.actual_stage_num - 1:
|
||||||
outstanding_range = (0, 1)
|
outstanding_range = (0, 1)
|
||||||
else:
|
else:
|
||||||
outstanding_range = (actual_stage_num, actual_stage_num)
|
outstanding_range = (self.actual_stage_num, self.actual_stage_num)
|
||||||
return outstanding_range
|
self.outstanding_range = outstanding_range
|
||||||
|
|
||||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||||
@ -189,20 +213,16 @@ class Worker:
|
|||||||
|
|
||||||
def get_output_by_key(self, key: UniqueKey) -> Any:
|
def get_output_by_key(self, key: UniqueKey) -> Any:
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
while key not in self.output_list:
|
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||||
self.output_list_condition_lock.wait()
|
|
||||||
|
|
||||||
output_work_item = self.output_list[key]
|
output_work_item = self.output_list[key]
|
||||||
|
|
||||||
output = output_work_item.output.wait()
|
output = output_work_item.output.wait()
|
||||||
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
||||||
output_work_item.refcount += 1
|
output_work_item.refcount += 1
|
||||||
|
|
||||||
# all consumers have been satisfied, the work_item can be released
|
# all consumers have been satisfied, the work_item can be released
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
if output_work_item.refcount == len(self.consumer_stage_ids):
|
if output_work_item.refcount >= len(self.consumer_stage_ids):
|
||||||
self.output_list.pop(key)
|
self.output_list.pop(key)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_parameters(self) -> List[torch.Tensor]:
|
def get_parameters(self) -> List[torch.Tensor]:
|
||||||
@ -211,34 +231,28 @@ class Worker:
|
|||||||
def get_parameter_gradients(self) -> List[torch.Tensor]:
|
def get_parameter_gradients(self) -> List[torch.Tensor]:
|
||||||
return [p.grad for p in self.module_partition.parameters()]
|
return [p.grad for p in self.module_partition.parameters()]
|
||||||
|
|
||||||
def reset_pp_context(self):
|
|
||||||
self.forward_times = 0
|
|
||||||
self.backward_times = 0
|
|
||||||
self.outstanding = 0
|
|
||||||
self.microbatch_id_to_backward_cache.clear()
|
|
||||||
self.output_list.clear()
|
|
||||||
|
|
||||||
# just for first pp_rank
|
# just for first pp_rank
|
||||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||||
|
assert self.consumer_stage_ids is not None
|
||||||
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
|
output = self._get_future_by_device()
|
||||||
|
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
|
||||||
|
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
|
||||||
|
forward_only)
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
assert self.consumer_stage_ids is not None
|
|
||||||
consumer_num = len(self.consumer_stage_ids)
|
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
|
||||||
output = self._get_future_by_device()
|
|
||||||
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
|
|
||||||
|
|
||||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
|
|
||||||
self.num_microbatches, forward_only)
|
|
||||||
self.work_list[key] = work_item
|
self.work_list[key] = work_item
|
||||||
|
|
||||||
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
|
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
|
||||||
self.work_list_condition_lock.notify_all()
|
self.work_list_condition_lock.notify_all()
|
||||||
|
|
||||||
|
# just for last pp_rank
|
||||||
|
def set_labels(self, microbatch_id: int, microlabels: Any):
|
||||||
|
self.microbatch_id_to_labels[microbatch_id] = microlabels
|
||||||
|
|
||||||
# just for last pp_rank
|
# just for last pp_rank
|
||||||
def _begin_backward(self, microbatch_id: int):
|
def _begin_backward(self, microbatch_id: int):
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
assert self.producer_stage_ids is not None
|
assert self.producer_stage_ids is not None
|
||||||
producer_num = len(self.producer_stage_ids)
|
|
||||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||||
output = self._get_future_by_device()
|
output = self._get_future_by_device()
|
||||||
grad_wrt_loss = torch.tensor(1, device=self.device)
|
grad_wrt_loss = torch.tensor(1, device=self.device)
|
||||||
@ -272,15 +286,10 @@ class Worker:
|
|||||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
|
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
|
||||||
'magenta')
|
'magenta')
|
||||||
|
|
||||||
args = []
|
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||||
for i in range(producer_num):
|
microbatch_id, None, self.num_microbatches, forward_only)
|
||||||
producer_args = subscribe_forward_futures[i].wait()
|
|
||||||
args.extend(producer_args)
|
|
||||||
|
|
||||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
|
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||||
self.num_microbatches, forward_only)
|
|
||||||
|
|
||||||
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
|
||||||
# add work_item to work_list
|
# add work_item to work_list
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
@ -312,16 +321,11 @@ class Worker:
|
|||||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||||
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
||||||
|
|
||||||
args = []
|
|
||||||
for i in range(consumer_num):
|
|
||||||
consumer_args = subscribe_backward_futures[i].wait()
|
|
||||||
args.extend(consumer_args)
|
|
||||||
|
|
||||||
# flatten args
|
# flatten args
|
||||||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
|
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||||
self.num_microbatches, False)
|
microbatch_id, None, self.num_microbatches, False)
|
||||||
|
|
||||||
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||||
|
|
||||||
# add work_item to work_list
|
# add work_item to work_list
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
@ -351,63 +355,50 @@ class Worker:
|
|||||||
self.consumer_stage_ids.append(next_rank)
|
self.consumer_stage_ids.append(next_rank)
|
||||||
|
|
||||||
def _get_work_item_key(self) -> UniqueKey:
|
def _get_work_item_key(self) -> UniqueKey:
|
||||||
with self.work_list_condition_lock:
|
# execute backward first (if backward phase in work_list)
|
||||||
while len(self.work_list) == 0:
|
pp_rank = self.pp_rank
|
||||||
self.work_list_condition_lock.wait()
|
actual_stage_num = self.actual_stage_num
|
||||||
|
num_microbatches = self.num_microbatches
|
||||||
# each stage must do Key(microbatch_id=0, phase=FORWARD) first
|
is_last_stage = pp_rank == actual_stage_num - 1
|
||||||
# before doing the operation, reset the context first
|
|
||||||
if self.reset_key in self.work_list:
|
|
||||||
self.reset_pp_context()
|
|
||||||
|
|
||||||
# execute backward first (if backward phase in work_list)
|
|
||||||
pp_rank = self.pp_rank
|
|
||||||
actual_stage_num = self.actual_stage_num
|
|
||||||
num_microbatches = self.num_microbatches
|
|
||||||
is_last_stage = pp_rank == actual_stage_num - 1
|
|
||||||
select_work_list_key: UniqueKey = None
|
|
||||||
|
|
||||||
if self.outstanding_range:
|
|
||||||
if self.outstanding <= self.outstanding_range[0]:
|
|
||||||
target_phase = Phase.FORWARD
|
|
||||||
target_microbatch_id = self.forward_times
|
|
||||||
elif self.outstanding >= self.outstanding_range[1]:
|
|
||||||
target_phase = Phase.BACKWARD
|
|
||||||
target_microbatch_id = self.backward_times
|
|
||||||
else:
|
|
||||||
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
|
|
||||||
|
|
||||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
|
||||||
if target_key in self.work_list:
|
|
||||||
select_work_list_key = target_key
|
|
||||||
|
|
||||||
# change outstanding_range at:
|
|
||||||
# 1. forward times reach actual_stage_num, this is the end of continuous forward
|
|
||||||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
|
||||||
if not is_last_stage and \
|
|
||||||
select_work_list_key is not None and \
|
|
||||||
select_work_list_key.phase == Phase.FORWARD:
|
|
||||||
if select_work_list_key.microbatch_id == actual_stage_num - 1:
|
|
||||||
outstanding_min = actual_stage_num - pp_rank - 1
|
|
||||||
outstanding_max = actual_stage_num - pp_rank
|
|
||||||
self.outstanding_range = (outstanding_min, outstanding_max)
|
|
||||||
elif select_work_list_key.microbatch_id == num_microbatches - 1:
|
|
||||||
self.outstanding_range = (0, 0)
|
|
||||||
|
|
||||||
|
if self.outstanding_range:
|
||||||
|
if self.outstanding <= self.outstanding_range[0]:
|
||||||
|
target_phase = Phase.FORWARD
|
||||||
|
target_microbatch_id = self.forward_times
|
||||||
|
elif self.outstanding >= self.outstanding_range[1]:
|
||||||
|
target_phase = Phase.BACKWARD
|
||||||
|
target_microbatch_id = self.backward_times
|
||||||
else:
|
else:
|
||||||
if self.forward_times < num_microbatches:
|
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
|
||||||
target_phase = Phase.FORWARD
|
|
||||||
target_microbatch_id = self.forward_times
|
|
||||||
else:
|
|
||||||
target_phase = Phase.BACKWARD
|
|
||||||
target_microbatch_id = self.backward_times
|
|
||||||
|
|
||||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||||
|
|
||||||
if target_key in self.work_list:
|
# change outstanding_range at:
|
||||||
select_work_list_key = target_key
|
# 1. forward times reach actual_stage_num, this is the end of continuous forward
|
||||||
|
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
||||||
|
if not is_last_stage and \
|
||||||
|
target_key.phase == Phase.FORWARD:
|
||||||
|
if target_key.microbatch_id == actual_stage_num - 1:
|
||||||
|
outstanding_min = actual_stage_num - pp_rank - 1
|
||||||
|
outstanding_max = actual_stage_num - pp_rank
|
||||||
|
self.outstanding_range = (outstanding_min, outstanding_max)
|
||||||
|
elif target_key.microbatch_id == num_microbatches - 1:
|
||||||
|
self.outstanding_range = (0, 0)
|
||||||
|
|
||||||
return select_work_list_key
|
else:
|
||||||
|
if self.forward_times < num_microbatches:
|
||||||
|
target_phase = Phase.FORWARD
|
||||||
|
target_microbatch_id = self.forward_times
|
||||||
|
else:
|
||||||
|
target_phase = Phase.BACKWARD
|
||||||
|
target_microbatch_id = self.backward_times
|
||||||
|
|
||||||
|
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||||
|
|
||||||
|
with self.work_list_condition_lock:
|
||||||
|
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||||
|
|
||||||
|
return target_key
|
||||||
|
|
||||||
def _consume_work_item_by_phase(self, work_item: WorkItem):
|
def _consume_work_item_by_phase(self, work_item: WorkItem):
|
||||||
phase = work_item.phase
|
phase = work_item.phase
|
||||||
@ -421,7 +412,7 @@ class Worker:
|
|||||||
is_first_stage = (self.pp_rank == 0)
|
is_first_stage = (self.pp_rank == 0)
|
||||||
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
|
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
|
||||||
|
|
||||||
# if self.pp_rank == 3:
|
# if self.pp_rank == 0:
|
||||||
# print(
|
# print(
|
||||||
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
|
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
|
||||||
# )
|
# )
|
||||||
@ -436,12 +427,7 @@ class Worker:
|
|||||||
|
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
self.outstanding += 1
|
self.outstanding += 1
|
||||||
|
args = get_real_args(args)
|
||||||
# TODO : more elegant ?
|
|
||||||
for i in range(len(args)):
|
|
||||||
arg_obj = args[i]
|
|
||||||
if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
|
|
||||||
args[i] = arg_obj.requires_grad_()
|
|
||||||
|
|
||||||
# last stage doesn't need to do checkpoint, for it will do backward instantly
|
# last stage doesn't need to do checkpoint, for it will do backward instantly
|
||||||
if forward_only:
|
if forward_only:
|
||||||
@ -458,7 +444,14 @@ class Worker:
|
|||||||
use_checkpoint = True
|
use_checkpoint = True
|
||||||
else:
|
else:
|
||||||
consume_result = self.module_partition(*args, **kwargs)
|
consume_result = self.module_partition(*args, **kwargs)
|
||||||
stage_outputs = consume_result
|
if is_last_stage and self.criterion:
|
||||||
|
labels = self.microbatch_id_to_labels.pop(microbatch_id)
|
||||||
|
loss: torch.Tensor = self.criterion(consume_result, labels)
|
||||||
|
consume_result = loss.item()
|
||||||
|
else:
|
||||||
|
loss = consume_result
|
||||||
|
|
||||||
|
stage_outputs = loss
|
||||||
stage_inputs = args
|
stage_inputs = args
|
||||||
use_checkpoint = False
|
use_checkpoint = False
|
||||||
|
|
||||||
@ -467,7 +460,8 @@ class Worker:
|
|||||||
stage_outputs,
|
stage_outputs,
|
||||||
checkpoint=use_checkpoint)
|
checkpoint=use_checkpoint)
|
||||||
|
|
||||||
consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result
|
consume_result = [consume_result] if isinstance(consume_result,
|
||||||
|
(torch.Tensor, int, float)) else consume_result
|
||||||
|
|
||||||
# if not forward_only, do the backward
|
# if not forward_only, do the backward
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
@ -488,20 +482,21 @@ class Worker:
|
|||||||
|
|
||||||
stage_outputs = backward_cache.stage_outputs
|
stage_outputs = backward_cache.stage_outputs
|
||||||
stage_inputs = backward_cache.stage_inputs
|
stage_inputs = backward_cache.stage_inputs
|
||||||
grad_tensors = args
|
|
||||||
|
|
||||||
use_checkpoint = backward_cache.checkpoint
|
use_checkpoint = backward_cache.checkpoint
|
||||||
|
|
||||||
if use_checkpoint:
|
if use_checkpoint:
|
||||||
stage_outputs = [self.module_partition(*stage_inputs)]
|
stage_outputs = [self.module_partition(*stage_inputs)]
|
||||||
|
# overlap recompute and future.wait
|
||||||
|
grad_tensors = get_real_args(args)
|
||||||
|
|
||||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||||
|
|
||||||
# collect grad of input tensor
|
# collect grad of input tensor
|
||||||
consume_result = []
|
consume_result = []
|
||||||
for input_node in stage_inputs:
|
if not is_first_stage:
|
||||||
if isinstance(input_node, torch.Tensor):
|
for input_node in stage_inputs:
|
||||||
consume_result.append(input_node.grad)
|
if isinstance(input_node, torch.Tensor):
|
||||||
|
consume_result.append(input_node.grad)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||||
@ -509,7 +504,20 @@ class Worker:
|
|||||||
return consume_result
|
return consume_result
|
||||||
|
|
||||||
def _get_store_len(self):
|
def _get_store_len(self):
|
||||||
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'
|
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
|
||||||
|
|
||||||
|
def _get_parameter_grad_sum(self):
|
||||||
|
grad_sum = 0
|
||||||
|
for p in self.module_partition.parameters():
|
||||||
|
if p.grad is not None:
|
||||||
|
grad_sum += p.grad.sum()
|
||||||
|
return grad_sum
|
||||||
|
|
||||||
|
def _is_first_step(self, work_item) -> bool:
|
||||||
|
return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
|
||||||
|
|
||||||
|
def _is_last_step(self, work_item) -> bool:
|
||||||
|
return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
|
||||||
|
|
||||||
# do the main loop to consume ready_list
|
# do the main loop to consume ready_list
|
||||||
def _work_loop(self):
|
def _work_loop(self):
|
||||||
@ -519,8 +527,6 @@ class Worker:
|
|||||||
# main loop
|
# main loop
|
||||||
while True:
|
while True:
|
||||||
work_item_key = self._get_work_item_key()
|
work_item_key = self._get_work_item_key()
|
||||||
if work_item_key is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# move current work item to output_list to activate subscribe in advance
|
# move current work item to output_list to activate subscribe in advance
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
@ -540,20 +546,32 @@ class Worker:
|
|||||||
color_debug(
|
color_debug(
|
||||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
|
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
|
||||||
'work loop', 'green')
|
'work loop', 'green')
|
||||||
|
|
||||||
work_item.output.set_result(consume_result)
|
work_item.output.set_result(consume_result)
|
||||||
|
|
||||||
|
# if is last step in one batch reset context and do step
|
||||||
|
if self._is_last_step(work_item):
|
||||||
|
if hasattr(self, 'optimizer'):
|
||||||
|
self.step()
|
||||||
|
self.forward_times = 0
|
||||||
|
self.backward_times = 0
|
||||||
|
self.outstanding = 0
|
||||||
|
self._initialize_outstanding_range()
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||||
|
|
||||||
def step(self):
|
def wait_for_step(self):
|
||||||
assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
|
self.step_lock.acquire()
|
||||||
self.work_list.clear()
|
|
||||||
self.output_list.clear()
|
|
||||||
self.microbatch_id_to_backward_cache.clear()
|
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
self.step_lock.release()
|
||||||
|
|
||||||
|
|
||||||
class PipelineEngineBase(ABC, nn.Module):
|
class PipelineEngineBase(ABC, nn.Module):
|
||||||
|
|
||||||
@ -564,10 +582,12 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
device: str,
|
device: str,
|
||||||
use_1F1B=False,
|
use_1F1B=False,
|
||||||
chunk: int = 1,
|
chunk: int = 1,
|
||||||
|
criterion: Callable = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module_partitions: List[nn.Module] = module_partitions
|
self.module_partitions: List[nn.Module] = module_partitions
|
||||||
self.chunk = chunk
|
self.chunk = chunk
|
||||||
|
self.criterion = criterion
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.device = device
|
self.device = device
|
||||||
self.use_1F1B = use_1F1B
|
self.use_1F1B = use_1F1B
|
||||||
@ -577,6 +597,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||||
|
|
||||||
|
self.step_futs: List[Future] = []
|
||||||
|
|
||||||
self._check_argument()
|
self._check_argument()
|
||||||
self._create_pp_rank_to_rpc_worker_id()
|
self._create_pp_rank_to_rpc_worker_id()
|
||||||
self._init_worker()
|
self._init_worker()
|
||||||
@ -613,6 +635,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
num_microbatches = self.num_microbatches
|
num_microbatches = self.num_microbatches
|
||||||
device = self.device
|
device = self.device
|
||||||
|
criterion = self.criterion
|
||||||
|
|
||||||
for pp_rank in range(actual_stage_num):
|
for pp_rank in range(actual_stage_num):
|
||||||
module_partition = self.module_partitions[pp_rank]
|
module_partition = self.module_partitions[pp_rank]
|
||||||
@ -622,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||||
Worker,
|
Worker,
|
||||||
args=(module_partition, pp_rank, actual_stage_num,
|
args=(module_partition, pp_rank, actual_stage_num,
|
||||||
num_microbatches, use_1F1B, device, checkpoint))
|
num_microbatches, use_1F1B, device, criterion,
|
||||||
|
checkpoint))
|
||||||
|
|
||||||
# let each worker know global worker rref (include itself)
|
# let each worker know global worker rref (include itself)
|
||||||
for pp_rank in range(actual_stage_num):
|
for pp_rank in range(actual_stage_num):
|
||||||
@ -646,12 +670,15 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
grads[stage_id].append(grad)
|
grads[stage_id].append(grad)
|
||||||
return grads
|
return grads
|
||||||
|
|
||||||
def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
|
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||||
|
if labels is not None:
|
||||||
|
assert len(batch) == len(labels)
|
||||||
|
|
||||||
num_microbatches = self.num_microbatches
|
num_microbatches = self.num_microbatches
|
||||||
microbatch_size = len(batch) // num_microbatches
|
microbatch_size = len(batch) // num_microbatches
|
||||||
actual_stage_num = self._get_actual_stage_num()
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
|
|
||||||
first_stage_worker = self.pp_rank_to_worker_rref[0]
|
first_worker_rref = self.pp_rank_to_worker_rref[0]
|
||||||
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
|
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
|
||||||
|
|
||||||
microbatch_iter = range(num_microbatches)
|
microbatch_iter = range(num_microbatches)
|
||||||
@ -659,11 +686,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
microbatch_iter = tqdm(microbatch_iter)
|
microbatch_iter = tqdm(microbatch_iter)
|
||||||
|
|
||||||
ret_future: List[Future] = [None] * num_microbatches
|
ret_future: List[Future] = [None] * num_microbatches
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
for microbatch_id in microbatch_iter:
|
for microbatch_id in microbatch_iter:
|
||||||
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
|
||||||
|
|
||||||
# control data input speed
|
# control data input speed
|
||||||
# to prevent exceed of wait limitations
|
# to prevent exceed of wait limitations
|
||||||
if microbatch_id >= actual_stage_num:
|
if microbatch_id >= actual_stage_num:
|
||||||
@ -671,15 +694,27 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
ret_future[microbatch_id - actual_stage_num].wait()
|
ret_future[microbatch_id - actual_stage_num].wait()
|
||||||
else:
|
else:
|
||||||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||||
first_stage_worker.rpc_sync().get_output_by_key(key)
|
first_worker_rref.rpc_sync().get_output_by_key(key)
|
||||||
|
|
||||||
# run one microbatch
|
# set input
|
||||||
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)
|
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||||
|
microbatch = microbatch.cuda()
|
||||||
|
first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
|
||||||
|
# set labels
|
||||||
|
if not forward_only and labels is not None:
|
||||||
|
microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||||
|
microlabels = microlabels.cuda()
|
||||||
|
last_worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||||
|
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
|
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
|
||||||
|
|
||||||
# wait forward
|
# wait for last backward in rank0
|
||||||
|
if not forward_only:
|
||||||
|
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||||
|
first_worker_rref.rpc_sync().get_output_by_key(key)
|
||||||
|
|
||||||
|
# collect forward result
|
||||||
# TODO : all the node to output
|
# TODO : all the node to output
|
||||||
forward_result = None
|
forward_result = None
|
||||||
|
|
||||||
@ -691,28 +726,30 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
for i in range(len(forward_result)):
|
for i in range(len(forward_result)):
|
||||||
forward_result[i].append(ret[i])
|
forward_result[i].append(ret[i])
|
||||||
|
|
||||||
# wait for last backward in rank0
|
if hasattr(self, 'optimizer_class'):
|
||||||
if not forward_only:
|
# wait for all step
|
||||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
# TODO : more elegant ?
|
||||||
first_stage_worker.rpc_sync().get_output_by_key(key)
|
for pp_rank in self.pp_rank_to_worker_rref:
|
||||||
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
|
worker_rref.rpc_sync().wait_for_step()
|
||||||
|
|
||||||
return forward_result
|
return forward_result
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
actual_stage_num = self._get_actual_stage_num()
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
|
self.optimizer_class = optimizer_class
|
||||||
for pp_rank in range(actual_stage_num):
|
for pp_rank in range(actual_stage_num):
|
||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
|
worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
step_futs: List[Future] = []
|
|
||||||
actual_stage_num = self._get_actual_stage_num()
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
for pp_rank in range(actual_stage_num):
|
for pp_rank in range(actual_stage_num):
|
||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
fut = worker_rref.rpc_async().step()
|
fut = worker_rref.rpc_async().step()
|
||||||
step_futs.append(fut)
|
self.step_futs.append(fut)
|
||||||
|
|
||||||
# wait for all optimizers
|
for fut in self.step_futs:
|
||||||
for fut in step_futs:
|
|
||||||
fut.wait()
|
fut.wait()
|
||||||
|
|
||||||
|
|
||||||
@ -724,9 +761,10 @@ class FillDrainPipelineEngine(PipelineEngineBase):
|
|||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
device: str,
|
device: str,
|
||||||
chunk: int = 1,
|
chunk: int = 1,
|
||||||
|
criterion: Callable = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
use_1F1B = False
|
use_1F1B = False
|
||||||
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
|
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||||
@ -737,6 +775,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
|
|||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
device: str,
|
device: str,
|
||||||
chunk: int = 1,
|
chunk: int = 1,
|
||||||
|
criterion: Callable = None,
|
||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
use_1F1B = True
|
use_1F1B = True
|
||||||
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
|
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
@ -45,6 +45,7 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--epoch', type=int, default=1)
|
parser.add_argument('--epoch', type=int, default=1)
|
||||||
parser.add_argument('--world_size', type=int, default=2)
|
parser.add_argument('--world_size', type=int, default=2)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=16)
|
||||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||||
parser.add_argument('--chunk', type=int, default=1)
|
parser.add_argument('--chunk', type=int, default=1)
|
||||||
parser.add_argument('--use_checkpoint', action='store_true')
|
parser.add_argument('--use_checkpoint', action='store_true')
|
||||||
|
@ -43,7 +43,6 @@ def run_master(args):
|
|||||||
engine.initialize_optimizer(optimizer_class, lr=lr)
|
engine.initialize_optimizer(optimizer_class, lr=lr)
|
||||||
|
|
||||||
_ = engine.forward_backward(input_sample)
|
_ = engine.forward_backward(input_sample)
|
||||||
engine.step()
|
|
||||||
|
|
||||||
cuda_rpc_result = []
|
cuda_rpc_result = []
|
||||||
single_result = []
|
single_result = []
|
||||||
|
Loading…
Reference in New Issue
Block a user