@@ -8,20 +8,29 @@ from typing import Any, Callable, Dict, List, Tuple
 import torch
 import torch.distributed.rpc as rpc
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
-                                           split_batch, tensor_shape_list, type_detail)
-from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
+from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc.utils import (
+    get_batch_lengths,
+    pytree_filter,
+    pytree_map,
+    split_batch,
+    tensor_shape_list,
+    type_detail,
+)
 
 
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
     UPDATE = 2
     INPUT = 3
 
 
 class UniqueKey:
     __slots__ = ('microbatch_id', 'phase')
     microbatch_id: int
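# Illustrative sketch, not part of the patch: work items and outputs are keyed
# by a hashable (microbatch_id, phase) pair, so dict lookups and the
# condition-variable predicates used below can address exactly one unit of
# pipeline work. A minimal stand-in for UniqueKey:
from enum import Enum


class DemoPhase(Enum):
    FORWARD = 0
    BACKWARD = 1


class DemoKey:
    __slots__ = ('microbatch_id', 'phase')

    def __init__(self, microbatch_id: int, phase: DemoPhase):
        self.microbatch_id = microbatch_id
        self.phase = phase

    def __hash__(self):
        return hash((self.microbatch_id, self.phase))

    def __eq__(self, other):
        return (self.microbatch_id, self.phase) == (other.microbatch_id, other.phase)


work_list = {DemoKey(0, DemoPhase.FORWARD): 'forward work for microbatch 0'}
assert DemoKey(0, DemoPhase.FORWARD) in work_list    # equal keys hash alike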
@@ -134,6 +143,7 @@ class WorkerBase(ABC):
         self.partition_args = partition_args
         self.criterion = criterion
         self.metric = metric
+        self.reset = False
 
         # context to maintain loop
         self._initialize_context_container()
@@ -164,6 +174,7 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
+        self.reset_condition = threading.Condition(threading.Lock())
 
     def _initialize_partition(self):
         partition_fn = self.partition_fn
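# Illustrative sketch, not part of the patch: every queue above is guarded by
# a threading.Condition. Writers insert under the lock and notify_all();
# readers block in wait_for() until their key appears, as _work_loop does below.
import threading

cond = threading.Condition(threading.Lock())
queue = {}


def writer():
    with cond:
        queue['key'] = 'work item'
        cond.notify_all()    # wake all waiters; each rechecks its own predicate


def reader():
    with cond:
        cond.wait_for(lambda: 'key' in queue)    # releases the lock while blocked
        return queue.pop('key')


t = threading.Thread(target=writer)
t.start()
print(reader())
t.join()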
@@ -182,20 +193,23 @@ class WorkerBase(ABC):
         # construction of partition is executed after the registration of pp_rank_to_worker_rref
         self._initialize_partition()
 
-    def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
+    # ref_use works for the lifecycle counter:
+    # if ref_use is True, the lifecycle count is not increased.
+    def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
-            self.output_list.pop(key)
-
-        output_work_item.refcount += 1
+            self.output_list.pop(key)
+
+        if not ref_use:
+            output_work_item.refcount += 1
         refcount = output_work_item.refcount
         output = output_work_item.output
 
         if output_work_item.phase != Phase.INPUT:
             if output_work_item.phase == Phase.FORWARD:
                 # lifecycle management for DAG scheduler
                 lifecycle = len(self.get_consumer_stage_ids())
-                if self.is_model_output():    # an extra reference for scheduler collecting results
+                if self.is_model_output():    # an extra reference for scheduler collecting results
                     lifecycle += 1
                 with self.output_list_condition_lock:
                     # all consumers have been satisfied, the work_item can be released
@@ -203,14 +217,24 @@ class WorkerBase(ABC):
                     if refcount < lifecycle:
                         self.output_list[key] = output_work_item
                         self.output_list_condition_lock.notify_all()
+            elif output_work_item.phase == Phase.BACKWARD:
+                lifecycle = len(self.get_producer_stage_ids())
+                if self._is_last_step(output_work_item):
+                    lifecycle += 1    # an extra reference for scheduler collecting results
+                with self.output_list_condition_lock:
+                    # all producers have been satisfied, the work_item can be released
+                    # or put it into the work list again.
+                    if refcount < lifecycle:
+                        self.output_list[key] = output_work_item
+                        self.output_list_condition_lock.notify_all()
         else:
             with self.output_list_condition_lock:
                 self.output_list[key] = output_work_item
                 self.output_list_condition_lock.notify_all()
 
         if isinstance(output, Future):
             output = output.wait()
 
         return output
 
     def get_parameters(self) -> List[torch.Tensor]:
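# Illustrative sketch, not part of the patch: an output entry survives in
# output_list until every registered reader has fetched it once. Assuming two
# consumer stages plus the extra scheduler reference, lifecycle is 3:
lifecycle = 2 + 1    # len(consumer_stage_ids) + 1 for the scheduler
refcount = 0
output_list = {'key': 'stage output'}

for _ in range(lifecycle):
    value = output_list.pop('key')
    refcount += 1
    if refcount < lifecycle:        # readers remain, so put the entry back
        output_list['key'] = value

assert 'key' not in output_list     # released after the final reader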
@@ -257,13 +281,13 @@ class WorkerBase(ABC):
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
 
         if not self.use_middleware():
             # make args and kwargs
             args, kwargs = self._make_args_kwargs(microbatch)
 
             work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
-                                 self.num_microbatches, forward_only)
+                                 self.num_microbatches, forward_only)
             with self.work_list_condition_lock:
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
@@ -284,14 +308,14 @@ class WorkerBase(ABC):
                     self_arg_lst.append(arg_lst[off])
 
                 work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
-                                     self.num_microbatches, forward_only)
+                                     self.num_microbatches, forward_only)
                 with self.work_list_condition_lock:
                     self.work_list[key] = work_item
                     self.work_list_condition_lock.notify_all()
 
                 # put input tensors which other nodes need into output_list as Phase.INPUT
                 work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
-                                            self.num_microbatches, forward_only)
+                                            self.num_microbatches, forward_only)
 
                 with self.output_list_condition_lock:
                     self.output_list[recv_input_key] = work_item_remote
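# Illustrative sketch, not part of the patch: besides queueing its own forward
# work, the first stage stashes the raw microbatch in output_list under a
# Phase.INPUT key, so a non-adjacent stage wired straight to the model input
# can fetch it through the same get_output_by_key path as any stage output.
output_list = {}
arg_lst = ['microbatch tensors']

recv_input_key = (0, 'INPUT')          # stands in for UniqueKey(0, Phase.INPUT)
output_list[recv_input_key] = arg_lst  # published by stage 0 in set_input

# a downstream stage that need_model_input() reads the same entry by key
assert output_list[(0, 'INPUT')] is arg_lst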
@@ -317,7 +341,7 @@ class WorkerBase(ABC):
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()
 
     def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
         """
         You should call this function asynchronously
@@ -336,7 +360,7 @@ class WorkerBase(ABC):
         producer_stage_ids = self.get_producer_stage_ids()
         producer_num = len(producer_stage_ids)
         if self.need_model_input():
-            producer_num += 1    # for input partition
+            producer_num += 1    # for input partition
         subscribe_forward_futures: List[Future] = [None] * producer_num
 
         # TODO(jiangziyue) get single value instead of the whole output
@@ -344,26 +368,28 @@ class WorkerBase(ABC):
             producer_stage_id = 0
             producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
             producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-            subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+            subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
 
-            for i in range(0, producer_num-1):
+            for i in range(0, producer_num - 1):
                 producer_stage_id = producer_stage_ids[i]
                 producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
+                    producer_output_key)
 
         else:
             for i in range(producer_num):
                 producer_stage_id = producer_stage_ids[i]
                 producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
+                    producer_output_key)
 
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
-                                           microbatch_id, None, self.num_microbatches, forward_only)
+                                           microbatch_id, None, self.num_microbatches, forward_only)
 
         return work_item_from_producer
 
     # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
     def subscribe_producer(self, microbatch_id: int, forward_only: bool):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
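# Illustrative sketch, not part of the patch: each producer subscription above
# is a torch.futures.Future obtained via rpc_async(); the stage collects one
# future per producer and only blocks in wait() when the work item is actually
# consumed. A local analogue without RPC:
import torch
from torch.futures import Future

producer_num = 2
subscribe_forward_futures = [Future() for _ in range(producer_num)]

# producers complete asynchronously and in any order
subscribe_forward_futures[1].set_result(torch.ones(2))
subscribe_forward_futures[0].set_result(torch.zeros(2))

# the consumer resolves the futures only when it is ready to run forward
args = [fut.wait() for fut in subscribe_forward_futures]
print(args)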
@@ -377,20 +403,20 @@ class WorkerBase(ABC):
                 self.work_list[key] = work_item_from_producer
                 self.work_list_condition_lock.notify_all()
 
-    def subscribe_consumer(self, microbatch_id: int):
+    def _subscribe_consumer(self, microbatch_id: int):
         """
         You should call this function asynchronously
         """
         assert self.producer_stage_ids is not None
-        consumer_num = len(self.consumer_stage_ids)
-        assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
 
         stage_id = self.pp_rank
-        subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()
+        if not self.use_middleware():
+            consumer_stage_ids = self.consumer_stage_ids
+        else:
+            consumer_stage_ids = self.get_consumer_stage_ids()
+        consumer_num = len(consumer_stage_ids)
+        subscribe_backward_futures: List[Future] = [None] * consumer_num
         for i in range(consumer_num):
-            consumer_stage_id = self.consumer_stage_ids[i]
+            consumer_stage_id = consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
             consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
             subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
@@ -399,13 +425,20 @@ class WorkerBase(ABC):
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, False)
 
-        # add work_item to work_list
+        return work_item_from_consumer
+
+    def subscribe_consumer(self, microbatch_id: int):
+        key = UniqueKey(microbatch_id, Phase.BACKWARD)
         with self.work_list_condition_lock:
-            key = UniqueKey(microbatch_id, Phase.BACKWARD)
-            assert key not in self.work_list
-            self.work_list[key] = work_item_from_consumer
-            self.work_list_condition_lock.notify_all()
+            if key not in self.work_list:
+                # On the current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
+                # can only be executed once for every producer-consumer stage pair, which is necessary
+                # to count the lifecycle of the work_item. So keeping subscribe_consumer inside the same
+                # lock as the work_item queue operation guarantees the consistency of the lifecycle counter.
+                work_item_from_consumer = self._subscribe_consumer(microbatch_id)
+                self.work_list[key] = work_item_from_consumer
+                self.work_list_condition_lock.notify_all()
 
     def get_producer_stage_ids(self):
         producer_stage_ids = []
         rank = self.pp_rank
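# Illustrative sketch, not part of the patch: testing `key not in work_list`
# and inserting under one lock makes subscription idempotent, so a duplicate
# RPC cannot trigger get_output_by_key (and bump the lifecycle counter) twice.
import threading

lock = threading.Condition(threading.Lock())
work_list = {}
side_effects = 0


def subscribe(key):
    global side_effects
    with lock:
        if key not in work_list:    # test-and-insert is atomic under the lock
            side_effects += 1       # stands in for the refcount side effect
            work_list[key] = 'work item'
            lock.notify_all()


subscribe('mb0-backward')
subscribe('mb0-backward')    # duplicate call: no second side effect
assert side_effects == 1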
@@ -425,7 +458,7 @@ class WorkerBase(ABC):
             if partition_id != model_input_partition_id:
                 producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
         return producer_stage_ids
 
     def get_consumer_stage_ids(self):
         consumer_stage_ids = []
         rank = self.pp_rank
@@ -462,7 +495,7 @@ class WorkerBase(ABC):
         for i, id in enumerate(partition_ids):
             if id == partition_id:
                 return i
 
     def get_topo(self):
         with self.partition_condition_lock:
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
@@ -470,13 +503,13 @@ class WorkerBase(ABC):
             return self.module_partition._topo
         else:
             return None
 
     def use_middleware(self):
         topo = self.get_topo()
         return topo is not None
 
     # TODO(jiangziyue) get single value instead of the whole output
-    def _get_real_args_kwargs(self, args_or_kwargs):
+    def _get_real_args_kwargs_fwd(self, args_or_kwargs):
         if not self.use_middleware():
             args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
             if args_or_kwargs is not None:
@@ -491,8 +524,8 @@ class WorkerBase(ABC):
             if args_or_kwargs is not None:
                 if isinstance(args_or_kwargs, dict):
                     pass
-                else:
-                    flatten_args = []
-                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                else:
+                    flatten_args = []
+                    if self.is_first_stage():
+                        pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                     # TODO get by offset
@@ -525,7 +558,7 @@ class WorkerBase(ABC):
                             if stage_id == src_stage_id:
                                 src_index += i
                                 break
-                    else:    # data from input partition
+                    else:    # data from input partition
                         src_index = 0
                     # when output_len = 1, not iterable
                     if output_len == 1:
@@ -536,6 +569,55 @@ class WorkerBase(ABC):
                     args_or_kwargs = flatten_args
         return args_or_kwargs
 
+    # TODO(jiangziyue) get single value instead of the whole output
+    def _get_real_args_kwargs_bwd(self, args_or_kwargs):
+        if not self.use_middleware():
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    args_or_kwargs = flatten_args
+        else:
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                flatten_args = []
+                # TODO get by offset
+                topo: Topo = self.get_topo()
+                self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+                self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+                output_vals = self_partition.get_output_vals()
+                consumer_stage_ids = self.get_consumer_stage_ids()
+                for val_list in output_vals:
+                    # An output may be passed to many downstream stages.
+                    target = None
+                    for val_pos in val_list.get():
+                        dst_partition_id = val_pos.partition_id
+                        dst_offset = val_pos.offset
+                        dst_partition = topo.get_partition_by_id(dst_partition_id)
+                        input_len = len(dst_partition.get_input_vals())
+                        dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
+                        for i, stage_id in enumerate(consumer_stage_ids):
+                            if stage_id == dst_stage_id:
+                                dst_index = i
+                                break
+                        if input_len == 1:
+                            part_grad = args_or_kwargs[dst_index]
+                        else:
+                            part_grad = args_or_kwargs[dst_index][dst_offset]
+
+                        if target is None:
+                            target = part_grad
+                        elif part_grad is not None:
+                            target += part_grad
+                        else:
+                            continue
+                    flatten_args.append(target)
+                args_or_kwargs = flatten_args
+        return args_or_kwargs
+
     @abstractmethod
     def _get_work_item_key(self) -> UniqueKey:
         """
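# Illustrative sketch, not part of the patch: when one stage output feeds
# several consumer stages, _get_real_args_kwargs_bwd sums the gradients those
# consumers send back, skipping consumers that returned None:
import torch

# one partial gradient (or None) per consumer of the same output tensor
partial_grads = [torch.tensor([1.0, 2.0]), None, torch.tensor([0.5, 0.5])]

target = None
for part_grad in partial_grads:
    if target is None:
        target = part_grad
    elif part_grad is not None:
        target += part_grad

print(target)    # tensor([1.5000, 2.5000])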
@@ -547,7 +629,7 @@ class WorkerBase(ABC):
 
     def is_last_stage(self):
         return self.pp_rank == self.actual_stage_num - 1
 
     def need_model_input(self):
         need_input = False
         topo: Topo = self.get_topo()
@@ -558,10 +640,13 @@ class WorkerBase(ABC):
             if model_input_partition_id in partition_inputs:
                 need_input = True
         return not self.is_first_stage() and need_input
 
     def is_model_output(self):
         return self.is_last_stage()
 
+    def is_model_input(self):
+        return self.is_first_stage()
+
     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
             args = args_kwargs[0]
@@ -598,11 +683,16 @@ class WorkerBase(ABC):
 
         # parse and integrate args and kwargs
         if is_first_stage:
-            args = self._get_real_args_kwargs(args)
-            kwargs = self._get_real_args_kwargs(kwargs)
+            args = self._get_real_args_kwargs_fwd(args)
+            kwargs = self._get_real_args_kwargs_fwd(kwargs)
             args_kwargs = (args, kwargs)
         else:
-            args_kwargs = self._get_real_args_kwargs(args)
+            args_kwargs = self._get_real_args_kwargs_fwd(args)
 
+        if not forward_only:
+            pytree_map(args_kwargs,
+                       lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
+                       process_types=torch.Tensor)
+
         args, kwargs = data_process_func(args_kwargs)
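# Illustrative sketch, not part of the patch: before a non-forward-only pass,
# every floating-point tensor in the (nested) args is switched to
# requires_grad so backward can later flow through it. map_tensors below is a
# simplified stand-in for pytree_map(..., process_types=torch.Tensor):
import torch

args_kwargs = ((torch.rand(2), torch.tensor([1, 2])), {'x': torch.rand(3)})


def map_tensors(obj, fn):
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, (tuple, list)):
        return type(obj)(map_tensors(o, fn) for o in obj)
    if isinstance(obj, dict):
        return {k: map_tensors(v, fn) for k, v in obj.items()}
    return obj


map_tensors(args_kwargs,
            lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False))
assert args_kwargs[0][0].requires_grad and not args_kwargs[0][1].requires_grad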
@@ -694,21 +784,40 @@ class WorkerBase(ABC):
 
             # overlap recompute and future.wait
             if not is_last_stage:
-                grad_tensors = self._get_real_args_kwargs(args)
+                grad_tensors = self._get_real_args_kwargs_bwd(args)
             else:
                 grad_tensors = None
 
-            # take tensor only (for only tensor can do backward)
-            stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
-            grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
+            # TODO(jiangziyue): are all values that need backward torch.Tensor?
+            stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
+            grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
+
+            # output every input's grad to the producer, even if it has no grad (output None),
+            # to keep the offsets aligned with the topo's record.
+            if grad_tensors is not None:
+                filtered_outputs = []
+                filtered_grads = []
+                for i, grad in enumerate(grad_tensors):
+                    stage_output = stage_outputs[i]
+                    if stage_output.requires_grad and grad is not None:
+                        filtered_outputs.append(stage_output)
+                        filtered_grads.append(grad)
+
+                stage_outputs = filtered_outputs
+                grad_tensors = filtered_grads
+
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
             consume_result = []
             if not is_first_stage:
-                pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
-                pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
+                # In the current design, the input must be flattened args.
+                for arg in stage_input_args:
+                    if isinstance(arg, torch.Tensor):
+                        consume_result.append(arg.grad)
+                    else:
+                        consume_result.append(None)
 
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
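# Illustrative sketch, not part of the patch: stage backward pairs each stage
# output with the gradient received for it, filtering out outputs that do not
# require grad and grads that are None so the two lists stay offset-aligned:
import torch

x = torch.rand(3, requires_grad=True)
stage_outputs = [x * 2, torch.arange(3)]    # the second output has no grad path
grad_tensors = [torch.ones(3), None]        # gradients sent back by the consumer

pairs = [(o, g) for o, g in zip(stage_outputs, grad_tensors)
         if o.requires_grad and g is not None]
outs, grads = zip(*pairs)
torch.autograd.backward(list(outs), grad_tensors=list(grads))
print(x.grad)    # tensor([2., 2., 2.])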
@@ -740,11 +849,11 @@ class WorkerBase(ABC):
     def _hook_before_step(self):
         pass
 
-    def _reset_context(self):
-        self.forward_times = 0
-        self.backward_times = 0
-        self.outstanding = 0
-        self._initialize_outstanding_range()
+    # install the main loop to wait for next batch input
+    def _wait_for_reset(self):
+        with self.reset_condition:
+            self.reset_condition.wait_for(lambda: self.reset)
+            self.reset = False
 
     # do the main loop to consume ready_list
     def _work_loop(self):
@@ -755,10 +864,9 @@ class WorkerBase(ABC):
         # main loop
         while True:
             work_item_key = self._get_work_item_key()
 
             # move current work item to output_list to activate subscribe in advance
             with self.work_list_condition_lock:
-                #self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
                 self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
                 work_item = self.work_list[work_item_key]
 
             with self.output_list_condition_lock:
@@ -768,16 +876,32 @@ class WorkerBase(ABC):
 
             consume_result = self._consume_work_item_by_phase(work_item)
 
-            work_item.output.set_result(consume_result)
             with self.work_list_condition_lock:
                 self.work_list.pop(work_item_key)
+            work_item.output.set_result(consume_result)
 
             # if this is the last step in one batch, reset context and do step
             if self._is_last_step(work_item):
                 self._hook_before_step()
                 if hasattr(self, 'optimizer') and not work_item.forward_only:
                     self.step()
-                self._reset_context()
+                self._wait_for_reset()
+
+    # reset context and resume the loop
+    def reset_context(self):
+        self.forward_times = 0
+        self.backward_times = 0
+        self.outstanding = 0
+        self._initialize_outstanding_range()
+        with self.work_list_condition_lock:
+            self.work_list.clear()
+
+        with self.output_list_condition_lock:
+            self.output_list.clear()
+
+        with self.reset_condition:
+            self.reset = True
+            self.reset_condition.notify_all()
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
         # TODO(jiangziyue) it's temporary code to deal with empty module partition.
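# Illustrative sketch, not part of the patch: after the last step of a batch,
# each worker parks in _wait_for_reset() until the engine's reset_context()
# RPC flips the flag and notifies; the work loop then resumes for a new batch.
import threading

reset = False
reset_condition = threading.Condition(threading.Lock())


def worker_wait_for_reset():
    global reset
    with reset_condition:
        reset_condition.wait_for(lambda: reset)
        reset = False    # consume the signal so the next batch can reuse it


def engine_reset_context():
    global reset
    with reset_condition:
        reset = True
        reset_condition.notify_all()


t = threading.Thread(target=worker_wait_for_reset)
t.start()
engine_reset_context()
t.join()
print('worker resumed')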
@@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def _create_pp_rank_to_rpc_worker_id(self) -> None:
         """create a map from model partition to stage_id, which is useful when use_interleave is True.
-        e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
+        e.g. If a model is split into 4 parts, which means stage_num is 2 and chunk is 2, then
         pp_rank_to_rpc_worker_id = [0, 1, 0, 1], which means the first and third parts
         of the partitions will be moved to device 0 and the others to device 1
         """
@@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
             key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
             for pp_rank in input_pp_ranks:
                 worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().get_output_by_key(key)
+                worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
 
     def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
         num_microbatches = self.num_microbatches
@@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
             # TODO : add relationship between output_pp_ranks and parts of microlabels
             worker_rref.remote().set_labels(microbatch_id, microlabels)
 
+    # TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
     def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         for pp_rank in output_pp_ranks:
@@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
 
         return forward_result
 
+    def _reset_worker(self):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().reset_context()
+            self.step_futs.append(fut)
+
+        for fut in self.step_futs:
+            fut.wait()
+
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
         batch_lengths = get_batch_lengths(batch)
         batch_length = batch_lengths[0]
@@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             worker_rref.rpc_sync().wait_for_step()
 
+        self._reset_worker()    # reset worker attributes for next batch
         return forward_result
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):