[PP Middleware] Add bwd and step for PP middleware (#2111)

* add bwd and step for PP middleware

* pre-commit

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Authored by Ziyue Jiang, 2022-12-12 12:40:03 +08:00; committed via GitHub
parent 8afc001f4f
commit 09d69e1c25
5 changed files with 225 additions and 82 deletions


@@ -8,20 +8,29 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
-                                           split_batch, tensor_shape_list, type_detail)
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc.utils import (
+    get_batch_lengths,
+    pytree_filter,
+    pytree_map,
+    split_batch,
+    tensor_shape_list,
+    type_detail,
+)

class Phase(Enum):
    FORWARD = 0
    BACKWARD = 1
    UPDATE = 2
    INPUT = 3

class UniqueKey:
    __slots__ = ('microbatch_id', 'phase')
    microbatch_id: int
@@ -134,6 +143,7 @@ class WorkerBase(ABC):
        self.partition_args = partition_args
        self.criterion = criterion
        self.metric = metric
+        self.reset = False

        # context to maintain loop
        self._initialize_context_container()
@@ -164,6 +174,7 @@ class WorkerBase(ABC):
        self.work_list_condition_lock = threading.Condition(threading.Lock())
        self.output_list_condition_lock = threading.Condition(threading.Lock())
        self.label_lock = threading.Condition(threading.Lock())
+        self.reset_condition = threading.Condition(threading.Lock())

    def _initialize_partition(self):
        partition_fn = self.partition_fn
@@ -182,20 +193,23 @@ class WorkerBase(ABC):
        # construction of partition is executed after the registration of pp_rank_to_worker_rref
        self._initialize_partition()

-    def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
+    # ref_use affects the lifecycle counter:
+    # if ref_use is True, the refcount is not increased.
+    def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
        with self.output_list_condition_lock:
            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
            output_work_item = self.output_list[key]
            self.output_list.pop(key)

-        output_work_item.refcount += 1
+        if not ref_use:
+            output_work_item.refcount += 1
        refcount = output_work_item.refcount
        output = output_work_item.output
-        if output_work_item.phase != Phase.INPUT:
+        if output_work_item.phase == Phase.FORWARD:
            # lifecycle management for DAG scheduler
            lifecycle = len(self.get_consumer_stage_ids())
            if self.is_model_output():    # an extra reference for scheduler collecting results
                lifecycle += 1

            with self.output_list_condition_lock:
                # all consumers have been satisfied, the work_item can be released
@@ -203,14 +217,24 @@ class WorkerBase(ABC):
                if refcount < lifecycle:
                    self.output_list[key] = output_work_item
                    self.output_list_condition_lock.notify_all()
+        elif output_work_item.phase == Phase.BACKWARD:
+            lifecycle = len(self.get_producer_stage_ids())
+            if self._is_last_step(output_work_item):
+                lifecycle += 1    # an extra reference for scheduler collecting results
+
+            with self.output_list_condition_lock:
+                # all producers have been satisfied, the work_item can be released
+                # or put it into work list again.
+                if refcount < lifecycle:
+                    self.output_list[key] = output_work_item
+                    self.output_list_condition_lock.notify_all()
        else:
            with self.output_list_condition_lock:
                self.output_list[key] = output_work_item
                self.output_list_condition_lock.notify_all()

        if isinstance(output, Future):
            output = output.wait()

        return output

    def get_parameters(self) -> List[torch.Tensor]:
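
Note on the lifecycle counting in this hunk: a FORWARD output stays in output_list until every consumer stage has fetched it, the new BACKWARD branch keeps a gradient output until every producer stage has fetched it, and ref_use=True lets the engine peek without consuming a reference. A minimal, self-contained sketch of that release rule (OutputSlot and its names are illustrative, not the library API):

    # Illustrative stand-in for the refcount/lifecycle rule in get_output_by_key.
    class OutputSlot:
        def __init__(self, value, num_waiters, extra_scheduler_ref=False):
            self.value = value
            self.refcount = 0
            # lifecycle = how many fetches must happen before the slot can be dropped
            self.lifecycle = num_waiters + (1 if extra_scheduler_ref else 0)

        def fetch(self, ref_use=False):
            if not ref_use:    # ref_use=True peeks without consuming a reference
                self.refcount += 1
            released = self.refcount >= self.lifecycle
            return self.value, released

    slot = OutputSlot(value=42, num_waiters=2, extra_scheduler_ref=True)
    print(slot.fetch())               # (42, False) -- first consumer
    print(slot.fetch(ref_use=True))   # (42, False) -- peek, refcount unchanged
    print(slot.fetch())               # (42, False) -- second consumer
    print(slot.fetch())               # (42, True)  -- scheduler reference, slot can be dropped
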
@@ -257,13 +281,13 @@ class WorkerBase(ABC):
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        output = self._get_future_by_device()

        if not self.use_middleware():
            # make args and kwargs
            args, kwargs = self._make_args_kwargs(microbatch)

            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
                                 self.num_microbatches, forward_only)
            with self.work_list_condition_lock:
                self.work_list[key] = work_item
                self.work_list_condition_lock.notify_all()
@@ -284,14 +308,14 @@ class WorkerBase(ABC):
                    self_arg_lst.append(arg_lst[off])

                work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
                                     self.num_microbatches, forward_only)
                with self.work_list_condition_lock:
                    self.work_list[key] = work_item
                    self.work_list_condition_lock.notify_all()

            # put input tensor which other nodes need into output_list as Phase.INPUT
            work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
                                        self.num_microbatches, forward_only)
            with self.output_list_condition_lock:
                self.output_list[recv_input_key] = work_item_remote
@@ -317,7 +341,7 @@ class WorkerBase(ABC):
                self.work_list[key] = work_item
                self.work_list_condition_lock.notify_all()

    def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
        """
        You should call this function asynchronously
@@ -336,7 +360,7 @@ class WorkerBase(ABC):
        producer_stage_ids = self.get_producer_stage_ids()
        producer_num = len(producer_stage_ids)
        if self.need_model_input():
            producer_num += 1    # for input partition
        subscribe_forward_futures: List[Future] = [None] * producer_num

        # TODO(jiangziyue) get single value instead of the whole output
@@ -344,26 +368,28 @@ class WorkerBase(ABC):
            producer_stage_id = 0
            producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
            producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-            subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+            subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)

-            for i in range(0, producer_num-1):
+            for i in range(0, producer_num - 1):
                producer_stage_id = producer_stage_ids[i]
                producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
+                    producer_output_key)

        else:
            for i in range(producer_num):
                producer_stage_id = producer_stage_ids[i]
                producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
+                    producer_output_key)

        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, forward_only)

        return work_item_from_producer

    # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
    def subscribe_producer(self, microbatch_id: int, forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
@@ -377,20 +403,20 @@ class WorkerBase(ABC):
            self.work_list[key] = work_item_from_producer
            self.work_list_condition_lock.notify_all()

-    def subscribe_consumer(self, microbatch_id: int):
+    def _subscribe_consumer(self, microbatch_id: int):
        """
        You should call this function asynchronously
        """
-        assert self.producer_stage_ids is not None
-        consumer_num = len(self.consumer_stage_ids)
-        assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
        stage_id = self.pp_rank
-        subscribe_backward_futures: List[Future] = [None] * consumer_num
        output = self._get_future_by_device()
+        if not self.use_middleware():
+            consumer_stage_ids = self.consumer_stage_ids
+        else:
+            consumer_stage_ids = self.get_consumer_stage_ids()
+        consumer_num = len(consumer_stage_ids)
+        subscribe_backward_futures: List[Future] = [None] * consumer_num
        for i in range(consumer_num):
-            consumer_stage_id = self.consumer_stage_ids[i]
+            consumer_stage_id = consumer_stage_ids[i]
            consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
            consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
            subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
@@ -399,13 +425,20 @@ class WorkerBase(ABC):
        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, False)

-        # add work_item to work_list
+        return work_item_from_consumer
+
+    def subscribe_consumer(self, microbatch_id: int):
+        key = UniqueKey(microbatch_id, Phase.BACKWARD)
        with self.work_list_condition_lock:
-            key = UniqueKey(microbatch_id, Phase.BACKWARD)
-            assert key not in self.work_list
-            self.work_list[key] = work_item_from_consumer
-            self.work_list_condition_lock.notify_all()
+            if key not in self.work_list:
+                # On the current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
+                # can only be executed once for every producer-consumer stage pair, which is necessary
+                # to count the lifecycle of work_item. So, keeping subscribe_consumer inside the same
+                # lock as the work_item queue operation guarantees the consistency of the lifecycle counter.
+                work_item_from_consumer = self._subscribe_consumer(microbatch_id)
+                self.work_list[key] = work_item_from_consumer
+                self.work_list_condition_lock.notify_all()

    def get_producer_stage_ids(self):
        producer_stage_ids = []
        rank = self.pp_rank
@@ -425,7 +458,7 @@ class WorkerBase(ABC):
            if partition_id != model_input_partition_id:
                producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
        return producer_stage_ids

    def get_consumer_stage_ids(self):
        consumer_stage_ids = []
        rank = self.pp_rank
@@ -462,7 +495,7 @@ class WorkerBase(ABC):
        for i, id in enumerate(partition_ids):
            if id == partition_id:
                return i

    def get_topo(self):
        with self.partition_condition_lock:
            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
@@ -470,13 +503,13 @@ class WorkerBase(ABC):
                return self.module_partition._topo
            else:
                return None

    def use_middleware(self):
        topo = self.get_topo()
        return topo is not None

    # TODO(jiangziyue) get single value instead of the whole output
-    def _get_real_args_kwargs(self, args_or_kwargs):
+    def _get_real_args_kwargs_fwd(self, args_or_kwargs):
        if not self.use_middleware():
            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
            if args_or_kwargs is not None:
@@ -491,8 +524,8 @@ class WorkerBase(ABC):
            if args_or_kwargs is not None:
                if isinstance(args_or_kwargs, dict):
                    pass
                else:
                    flatten_args = []
                    if self.is_first_stage():
                        pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                    # TODO get by offset
@@ -525,7 +558,7 @@ class WorkerBase(ABC):
                            if stage_id == src_stage_id:
                                src_index += i
                                break
                        else:    # data from input partition
                            src_index = 0
                        # when output_len = 1, not iterable
                        if output_len == 1:
@@ -536,6 +569,55 @@ class WorkerBase(ABC):
                    args_or_kwargs = flatten_args
        return args_or_kwargs

+    # TODO(jiangziyue) get single value instead of the whole output
+    def _get_real_args_kwargs_bwd(self, args_or_kwargs):
+        if not self.use_middleware():
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    args_or_kwargs = flatten_args
+        else:
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                flatten_args = []
+                # TODO get by offset
+                topo: Topo = self.get_topo()
+                self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                output_vals = self_partition.get_output_vals()
+                consumer_stage_ids = self.get_consumer_stage_ids()
+                for val_list in output_vals:
+                    # An output may be passed to many down stages.
+                    target = None
+                    for val_pos in val_list.get():
+                        dst_partition_id = val_pos.partition_id
+                        dst_offset = val_pos.offset
+                        dst_partition = topo.get_partition_by_id(dst_partition_id)
+                        input_len = len(dst_partition.get_input_vals())
+                        dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
+                        for i, stage_id in enumerate(consumer_stage_ids):
+                            if stage_id == dst_stage_id:
+                                dst_index = i
+                                break
+                        if input_len == 1:
+                            part_grad = args_or_kwargs[dst_index]
+                        else:
+                            part_grad = args_or_kwargs[dst_index][dst_offset]
+                        if target is None:
+                            target = part_grad
+                        elif part_grad is not None:
+                            target += part_grad
+                        else:
+                            continue
+                    flatten_args.append(target)
+                args_or_kwargs = flatten_args
+        return args_or_kwargs

    @abstractmethod
    def _get_work_item_key(self) -> UniqueKey:
        """
@@ -547,7 +629,7 @@ class WorkerBase(ABC):
    def is_last_stage(self):
        return self.pp_rank == self.actual_stage_num - 1

    def need_model_input(self):
        need_input = False
        topo: Topo = self.get_topo()
@@ -558,10 +640,13 @@ class WorkerBase(ABC):
        if model_input_partition_id in partition_inputs:
            need_input = True
        return not self.is_first_stage() and need_input

    def is_model_output(self):
        return self.is_last_stage()

+    def is_model_input(self):
+        return self.is_first_stage()
+
    def _default_data_process_func(self, args_kwargs):
        if self.is_first_stage():
            args = args_kwargs[0]
@@ -598,11 +683,16 @@ class WorkerBase(ABC):
            # parse and integrate args and kwargs
            if is_first_stage:
-                args = self._get_real_args_kwargs(args)
-                kwargs = self._get_real_args_kwargs(kwargs)
+                args = self._get_real_args_kwargs_fwd(args)
+                kwargs = self._get_real_args_kwargs_fwd(kwargs)
                args_kwargs = (args, kwargs)
            else:
-                args_kwargs = self._get_real_args_kwargs(args)
+                args_kwargs = self._get_real_args_kwargs_fwd(args)
+
+            if not forward_only:
+                pytree_map(args_kwargs,
+                           lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
+                           process_types=torch.Tensor)

            args, kwargs = data_process_func(args_kwargs)
@@ -694,21 +784,40 @@ class WorkerBase(ABC):
            # overlap recompute and future.wait
            if not is_last_stage:
-                grad_tensors = self._get_real_args_kwargs(args)
+                grad_tensors = self._get_real_args_kwargs_bwd(args)
            else:
                grad_tensors = None

            # take tensor only (for only tensor can do backward)
-            stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
-            grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
+            # TODO(jiangziyue): are all values that need backward torch.Tensor?
+            stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
+            grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
+            # Every input's grad is sent back to its producer, even if it has no grad (None is sent),
+            # so that the offsets stay aligned with the topo's record.
+            if grad_tensors is not None:
+                filtered_outputs = []
+                filtered_grads = []
+                for i, grad in enumerate(grad_tensors):
+                    stage_output = stage_outputs[i]
+                    if stage_output.requires_grad and grad is not None:
+                        filtered_outputs.append(stage_output)
+                        filtered_grads.append(grad)
+                stage_outputs = filtered_outputs
+                grad_tensors = filtered_grads

            autograd.backward(stage_outputs, grad_tensors=grad_tensors)

            # collect grad of input tensor
            consume_result = []
            if not is_first_stage:
-                pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
-                pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
+                # In the current design, the input must be flattened args.
+                for arg in stage_input_args:
+                    if isinstance(arg, torch.Tensor):
+                        consume_result.append(arg.grad)
+                    else:
+                        consume_result.append(None)

        else:
            raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
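
The BACKWARD branch above pairs each stage output with the gradient received for it, drops pairs that cannot propagate (the output does not require grad, or the incoming grad is None), runs autograd.backward on the rest, and then reports a gradient, or None, for every input so positions stay aligned with the topology record. A small standalone illustration of that contract (tensors and shapes are made up):

    import torch
    from torch import autograd

    x = torch.randn(4, requires_grad=True)
    flag = torch.tensor([1.0])             # an input that never receives a gradient
    outputs = [x * 2, x.detach() + 1]      # second output does not require grad
    incoming = [torch.ones(4), None]       # one grad slot per output, None if absent

    # keep only (output, grad) pairs that can actually propagate
    pairs = [(o, g) for o, g in zip(outputs, incoming) if o.requires_grad and g is not None]
    autograd.backward([o for o, _ in pairs], grad_tensors=[g for _, g in pairs])

    # report a grad (or None) for every input, keeping the input order
    input_grads = [arg.grad if isinstance(arg, torch.Tensor) else None for arg in (x, flag)]
    print(input_grads)                     # [tensor([2., 2., 2., 2.]), None]
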
@@ -740,11 +849,11 @@ class WorkerBase(ABC):
    def _hook_before_step(self):
        pass

-    def _reset_context(self):
-        self.forward_times = 0
-        self.backward_times = 0
-        self.outstanding = 0
-        self._initialize_outstanding_range()
+    # stall the main loop until the next batch's input arrives
+    def _wait_for_reset(self):
+        with self.reset_condition:
+            self.reset_condition.wait_for(lambda: self.reset)
+            self.reset = False

    # do the main loop to consume ready_list
    def _work_loop(self):
@@ -755,10 +864,9 @@ class WorkerBase(ABC):
        # main loop
        while True:
            work_item_key = self._get_work_item_key()

            # move current work item to output_list to activate subscribe in advance
            with self.work_list_condition_lock:
-                #self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
+                self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
                work_item = self.work_list[work_item_key]

            with self.output_list_condition_lock:
@@ -768,16 +876,32 @@ class WorkerBase(ABC):
            consume_result = self._consume_work_item_by_phase(work_item)

-            work_item.output.set_result(consume_result)
            with self.work_list_condition_lock:
                self.work_list.pop(work_item_key)
+            work_item.output.set_result(consume_result)

            # if is last step in one batch reset context and do step
            if self._is_last_step(work_item):
                self._hook_before_step()
                if hasattr(self, 'optimizer') and not work_item.forward_only:
                    self.step()
-                self._reset_context()
+                self._wait_for_reset()
+
+    # reset context and resume loop
+    def reset_context(self):
+        self.forward_times = 0
+        self.backward_times = 0
+        self.outstanding = 0
+        self._initialize_outstanding_range()
+        with self.work_list_condition_lock:
+            self.work_list.clear()
+        with self.output_list_condition_lock:
+            self.output_list.clear()
+        with self.reset_condition:
+            self.reset = True
+            self.reset_condition.notify_all()

    def initialize_optimizer(self, optimizer_class: type, **kwargs):
        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
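
The _wait_for_reset / reset_context pair above is a plain condition-variable handshake: after the last step of a batch the worker loop blocks, and the engine's RPC call to reset_context clears the per-batch state, sets the flag, and wakes the loop for the next batch. A simplified stand-in using local threads (MiniWorker is illustrative, not the WorkerBase implementation):

    import threading
    import time

    class MiniWorker:
        def __init__(self):
            self.reset = False
            self.reset_condition = threading.Condition(threading.Lock())

        def work_loop(self, num_batches=2):
            for batch in range(num_batches):
                print(f"worker: finished batch {batch}, waiting for reset")
                with self.reset_condition:
                    self.reset_condition.wait_for(lambda: self.reset)
                    self.reset = False

        def reset_context(self):
            # in the real engine this arrives via worker_rref.rpc_async().reset_context()
            with self.reset_condition:
                self.reset = True
                self.reset_condition.notify_all()

    worker = MiniWorker()
    loop = threading.Thread(target=worker.work_loop)
    loop.start()
    for _ in range(2):
        time.sleep(0.1)
        worker.reset_context()
    loop.join()
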
@@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """create a map from model partition to stage_id, which is useful when use_interleave is True.
        e.g. If a model is split into 4 parts, which means stage_num is 2, chunk is 2, then
        pp_rank_to_rpc_worker_id = [0, 1, 0, 1], which means the first and third partitions
        will be moved to device 0 and the others to device 1
        """
@@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
        key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-            worker_rref.rpc_sync().get_output_by_key(key)
+            worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        num_microbatches = self.num_microbatches
@@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
            # TODO : add relationship between output_pp_ranks and parts of microlabels
            worker_rref.remote().set_labels(microbatch_id, microlabels)

+    # TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        for pp_rank in output_pp_ranks:
@@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
        return forward_result

+    def _reset_worker(self):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().reset_context()
+            self.step_futs.append(fut)
+
+        for fut in self.step_futs:
+            fut.wait()
+
    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
        batch_lengths = get_batch_lengths(batch)
        batch_length = batch_lengths[0]
@@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
            worker_rref.rpc_sync().wait_for_step()

+        self._reset_worker()    # reset worker attributes for next batch
        return forward_result

    def initialize_optimizer(self, optimizer_class: type, **kwargs):


@@ -89,9 +89,6 @@ class OneFOneBWorker(WorkerBase):
            elif target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

-        with self.work_list_condition_lock:
-            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
        return target_key


@@ -57,7 +57,6 @@ def split_batch(batch: Any, start, stop, device: str):
def type_detail(obj):
    return pytree_map(obj, lambda x: type(x), map_all=True)

def pytree_filter(fn, obj, process_types):
    if obj is None:
        return None


@@ -31,7 +31,7 @@ class MLP(nn.Module):
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
-        return x
+        return x.sum()

class DAG_MLP(nn.Module):
    def __init__(self, dim: int, layers: int):
@@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
        for layer in self.layers:
            x = layer(x)
            y = self.dag_layer(y)
-        return x, y
+        return x.sum(), y.sum()

class RpcTestModel(nn.Module):


@@ -41,10 +41,10 @@ def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int
    partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
    return partition

-def run_master(model_cls, world_size):
+def run_master(model_cls, world_size, forward_only):
    torch.manual_seed(100)

-    epoch = 10
+    epoch = 3
    device = 'cuda'
    stage_num = world_size
    chunk = 1
@@ -57,6 +57,10 @@ def run_master(model_cls, world_size):
            kwargs = dict(x=x)
            return kwargs

        model = model_cls(dim, stage_num * 3)
+        if forward_only:
+            labels = None
+        else:
+            labels = 1
    elif model_cls == DAG_MLP:

        def data_gen():
            x = torch.zeros((batch_size, dim))
@@ -64,24 +68,30 @@ def run_master(model_cls, world_size):
            kwargs = dict(x=x, y=y)
            return kwargs

        model = model_cls(dim, stage_num * 3)
+        if forward_only:
+            labels = None
+        else:
+            labels = 1
    else:
        pass

    data_kwargs = data_gen()

    engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    checkpoint=use_checkpoint,)
+    if not forward_only:
+        engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)

    for _ in range(epoch):
        input_x = torch.randn((batch_size, dim), device=device)
        input_y = torch.randn((batch_size, dim), device=device)
-        logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True)
+        logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)

-def run_worker(rank, model_cls, world_size, master_func):
+def run_worker(rank, model_cls, world_size, forward_only, master_func):
    master_addr = 'localhost'
    master_port = 29020
    os.environ['MASTER_ADDR'] = master_addr
@@ -99,19 +109,20 @@ def run_worker(rank, model_cls, world_size, master_func):
    # in rpc mode, only rank 0 is needed to be coded
    if rank == 0:
-        master_func(model_cls, world_size)
+        master_func(model_cls, world_size, forward_only)
    # barrier here
    if rpc_is_initialized():
        rpc.shutdown()

@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
+@parameterize('forward_only', [True, False])
@pytest.mark.dist
@rerun_if_address_is_in_use()
-def test_pp_middleware_fwd(model_cls):
+def test_pp_middleware_fwd(model_cls, forward_only):
    world_size = 4
    master_func = run_master
-    mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size)
+    mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)

if __name__ == "__main__":
    test_pp_middleware_fwd()