mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-07 11:03:58 +00:00
[PP Middleware] Add bwd and step for PP middleware (#2111)
* add bwd and step for PP middleware * pre-commit Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
parent
8afc001f4f
commit
09d69e1c25
@ -8,20 +8,29 @@ from typing import Any, Callable, Dict, List, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed.rpc as rpc
|
import torch.distributed.rpc as rpc
|
||||||
from colossalai.pipeline.pipeline_process_group import ppg
|
|
||||||
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
|
|
||||||
split_batch, tensor_shape_list, type_detail)
|
|
||||||
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
|
||||||
from torch import autograd, nn, optim
|
from torch import autograd, nn, optim
|
||||||
from torch._C._distributed_rpc import PyRRef
|
from torch._C._distributed_rpc import PyRRef
|
||||||
from torch.futures import Future
|
from torch.futures import Future
|
||||||
|
|
||||||
|
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||||
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
|
from colossalai.pipeline.rpc.utils import (
|
||||||
|
get_batch_lengths,
|
||||||
|
pytree_filter,
|
||||||
|
pytree_map,
|
||||||
|
split_batch,
|
||||||
|
tensor_shape_list,
|
||||||
|
type_detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Phase(Enum):
|
class Phase(Enum):
|
||||||
FORWARD = 0
|
FORWARD = 0
|
||||||
BACKWARD = 1
|
BACKWARD = 1
|
||||||
UPDATE = 2
|
UPDATE = 2
|
||||||
INPUT = 3
|
INPUT = 3
|
||||||
|
|
||||||
|
|
||||||
class UniqueKey:
|
class UniqueKey:
|
||||||
__slots__ = ('microbatch_id', 'phase')
|
__slots__ = ('microbatch_id', 'phase')
|
||||||
microbatch_id: int
|
microbatch_id: int
|
||||||
@ -134,6 +143,7 @@ class WorkerBase(ABC):
|
|||||||
self.partition_args = partition_args
|
self.partition_args = partition_args
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
|
self.reset = False
|
||||||
|
|
||||||
# context to maintain loop
|
# context to maintain loop
|
||||||
self._initialize_context_container()
|
self._initialize_context_container()
|
||||||
@ -164,6 +174,7 @@ class WorkerBase(ABC):
|
|||||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.label_lock = threading.Condition(threading.Lock())
|
self.label_lock = threading.Condition(threading.Lock())
|
||||||
|
self.reset_condition = threading.Condition(threading.Lock())
|
||||||
|
|
||||||
def _initialize_partition(self):
|
def _initialize_partition(self):
|
||||||
partition_fn = self.partition_fn
|
partition_fn = self.partition_fn
|
||||||
@ -182,20 +193,23 @@ class WorkerBase(ABC):
|
|||||||
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||||
self._initialize_partition()
|
self._initialize_partition()
|
||||||
|
|
||||||
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
|
# res_use works for lifecycle counter,
|
||||||
|
# if ref_use is True, lifecycle won't add.
|
||||||
|
def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||||
output_work_item = self.output_list[key]
|
output_work_item = self.output_list[key]
|
||||||
self.output_list.pop(key)
|
self.output_list.pop(key)
|
||||||
|
|
||||||
output_work_item.refcount += 1
|
if not ref_use:
|
||||||
|
output_work_item.refcount += 1
|
||||||
refcount = output_work_item.refcount
|
refcount = output_work_item.refcount
|
||||||
output = output_work_item.output
|
output = output_work_item.output
|
||||||
|
|
||||||
if output_work_item.phase != Phase.INPUT:
|
if output_work_item.phase == Phase.FORWARD:
|
||||||
# lifecycle management for DAG scheduler
|
# lifecycle management for DAG scheduler
|
||||||
lifecycle = len(self.get_consumer_stage_ids())
|
lifecycle = len(self.get_consumer_stage_ids())
|
||||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||||
lifecycle += 1
|
lifecycle += 1
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
# all consumers have been satisfied, the work_item can be released
|
# all consumers have been satisfied, the work_item can be released
|
||||||
@ -203,14 +217,24 @@ class WorkerBase(ABC):
|
|||||||
if refcount < lifecycle:
|
if refcount < lifecycle:
|
||||||
self.output_list[key] = output_work_item
|
self.output_list[key] = output_work_item
|
||||||
self.output_list_condition_lock.notify_all()
|
self.output_list_condition_lock.notify_all()
|
||||||
|
elif output_work_item.phase == Phase.BACKWARD:
|
||||||
|
lifecycle = len(self.get_producer_stage_ids())
|
||||||
|
if self._is_last_step(output_work_item):
|
||||||
|
lifecycle += 1 # an extra reference for scheduler collecting results
|
||||||
|
with self.output_list_condition_lock:
|
||||||
|
# all producers have been satisfied, the work_item can be released
|
||||||
|
# or put it into work list again.
|
||||||
|
if refcount < lifecycle:
|
||||||
|
self.output_list[key] = output_work_item
|
||||||
|
self.output_list_condition_lock.notify_all()
|
||||||
else:
|
else:
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
self.output_list[key] = output_work_item
|
self.output_list[key] = output_work_item
|
||||||
self.output_list_condition_lock.notify_all()
|
self.output_list_condition_lock.notify_all()
|
||||||
|
|
||||||
if isinstance(output, Future):
|
if isinstance(output, Future):
|
||||||
output = output.wait()
|
output = output.wait()
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_parameters(self) -> List[torch.Tensor]:
|
def get_parameters(self) -> List[torch.Tensor]:
|
||||||
@ -257,13 +281,13 @@ class WorkerBase(ABC):
|
|||||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
output = self._get_future_by_device()
|
output = self._get_future_by_device()
|
||||||
|
|
||||||
if not self.use_middleware():
|
if not self.use_middleware():
|
||||||
# make args and kwargs
|
# make args and kwargs
|
||||||
args, kwargs = self._make_args_kwargs(microbatch)
|
args, kwargs = self._make_args_kwargs(microbatch)
|
||||||
|
|
||||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||||
self.num_microbatches, forward_only)
|
self.num_microbatches, forward_only)
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
self.work_list[key] = work_item
|
self.work_list[key] = work_item
|
||||||
self.work_list_condition_lock.notify_all()
|
self.work_list_condition_lock.notify_all()
|
||||||
@ -284,14 +308,14 @@ class WorkerBase(ABC):
|
|||||||
self_arg_lst.append(arg_lst[off])
|
self_arg_lst.append(arg_lst[off])
|
||||||
|
|
||||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
||||||
self.num_microbatches, forward_only)
|
self.num_microbatches, forward_only)
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
self.work_list[key] = work_item
|
self.work_list[key] = work_item
|
||||||
self.work_list_condition_lock.notify_all()
|
self.work_list_condition_lock.notify_all()
|
||||||
|
|
||||||
# put input tensor which other nodes need into output_list as Phase.INPUT
|
# put input tensor which other nodes need into output_list as Phase.INPUT
|
||||||
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
||||||
self.num_microbatches, forward_only)
|
self.num_microbatches, forward_only)
|
||||||
|
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
self.output_list[recv_input_key] = work_item_remote
|
self.output_list[recv_input_key] = work_item_remote
|
||||||
@ -317,7 +341,7 @@ class WorkerBase(ABC):
|
|||||||
|
|
||||||
self.work_list[key] = work_item
|
self.work_list[key] = work_item
|
||||||
self.work_list_condition_lock.notify_all()
|
self.work_list_condition_lock.notify_all()
|
||||||
|
|
||||||
def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
||||||
"""
|
"""
|
||||||
You should call this function asynchronously
|
You should call this function asynchronously
|
||||||
@ -336,7 +360,7 @@ class WorkerBase(ABC):
|
|||||||
producer_stage_ids = self.get_producer_stage_ids()
|
producer_stage_ids = self.get_producer_stage_ids()
|
||||||
producer_num = len(producer_stage_ids)
|
producer_num = len(producer_stage_ids)
|
||||||
if self.need_model_input():
|
if self.need_model_input():
|
||||||
producer_num += 1 # for input partition
|
producer_num += 1 # for input partition
|
||||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||||
|
|
||||||
# TODO(jiangziyue) get single value instead of the whole output
|
# TODO(jiangziyue) get single value instead of the whole output
|
||||||
@ -344,26 +368,28 @@ class WorkerBase(ABC):
|
|||||||
producer_stage_id = 0
|
producer_stage_id = 0
|
||||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||||
|
|
||||||
for i in range(0, producer_num-1):
|
for i in range(0, producer_num - 1):
|
||||||
producer_stage_id = producer_stage_ids[i]
|
producer_stage_id = producer_stage_ids[i]
|
||||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||||
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||||
|
producer_output_key)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for i in range(producer_num):
|
for i in range(producer_num):
|
||||||
producer_stage_id = producer_stage_ids[i]
|
producer_stage_id = producer_stage_ids[i]
|
||||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||||
|
producer_output_key)
|
||||||
|
|
||||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||||
microbatch_id, None, self.num_microbatches, forward_only)
|
microbatch_id, None, self.num_microbatches, forward_only)
|
||||||
|
|
||||||
return work_item_from_producer
|
return work_item_from_producer
|
||||||
|
|
||||||
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
|
# TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
|
||||||
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
@ -377,20 +403,20 @@ class WorkerBase(ABC):
|
|||||||
self.work_list[key] = work_item_from_producer
|
self.work_list[key] = work_item_from_producer
|
||||||
self.work_list_condition_lock.notify_all()
|
self.work_list_condition_lock.notify_all()
|
||||||
|
|
||||||
def subscribe_consumer(self, microbatch_id: int):
|
def _subscribe_consumer(self, microbatch_id: int):
|
||||||
"""
|
"""
|
||||||
You should call this function asynchronously
|
You should call this function asynchronously
|
||||||
"""
|
"""
|
||||||
assert self.producer_stage_ids is not None
|
|
||||||
consumer_num = len(self.consumer_stage_ids)
|
|
||||||
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
|
|
||||||
|
|
||||||
stage_id = self.pp_rank
|
stage_id = self.pp_rank
|
||||||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
|
||||||
output = self._get_future_by_device()
|
output = self._get_future_by_device()
|
||||||
|
if not self.use_middleware():
|
||||||
|
consumer_stage_ids = self.consumer_stage_ids
|
||||||
|
else:
|
||||||
|
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||||
|
consumer_num = len(consumer_stage_ids)
|
||||||
|
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||||
for i in range(consumer_num):
|
for i in range(consumer_num):
|
||||||
consumer_stage_id = self.consumer_stage_ids[i]
|
consumer_stage_id = consumer_stage_ids[i]
|
||||||
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||||
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
||||||
@ -399,13 +425,20 @@ class WorkerBase(ABC):
|
|||||||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||||
microbatch_id, None, self.num_microbatches, False)
|
microbatch_id, None, self.num_microbatches, False)
|
||||||
|
|
||||||
# add work_item to work_list
|
return work_item_from_consumer
|
||||||
|
|
||||||
|
def subscribe_consumer(self, microbatch_id: int):
|
||||||
|
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
if key not in self.work_list:
|
||||||
assert key not in self.work_list
|
# On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
|
||||||
self.work_list[key] = work_item_from_consumer
|
# can only be executed once for every producer-consumer stage pair, which is necessary
|
||||||
self.work_list_condition_lock.notify_all()
|
# to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
|
||||||
|
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
|
||||||
|
work_item_from_consumer = self._subscribe_consumer(microbatch_id)
|
||||||
|
self.work_list[key] = work_item_from_consumer
|
||||||
|
self.work_list_condition_lock.notify_all()
|
||||||
|
|
||||||
def get_producer_stage_ids(self):
|
def get_producer_stage_ids(self):
|
||||||
producer_stage_ids = []
|
producer_stage_ids = []
|
||||||
rank = self.pp_rank
|
rank = self.pp_rank
|
||||||
@ -425,7 +458,7 @@ class WorkerBase(ABC):
|
|||||||
if partition_id != model_input_partition_id:
|
if partition_id != model_input_partition_id:
|
||||||
producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
|
producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
|
||||||
return producer_stage_ids
|
return producer_stage_ids
|
||||||
|
|
||||||
def get_consumer_stage_ids(self):
|
def get_consumer_stage_ids(self):
|
||||||
consumer_stage_ids = []
|
consumer_stage_ids = []
|
||||||
rank = self.pp_rank
|
rank = self.pp_rank
|
||||||
@ -462,7 +495,7 @@ class WorkerBase(ABC):
|
|||||||
for i, id in enumerate(partition_ids):
|
for i, id in enumerate(partition_ids):
|
||||||
if id == partition_id:
|
if id == partition_id:
|
||||||
return i
|
return i
|
||||||
|
|
||||||
def get_topo(self):
|
def get_topo(self):
|
||||||
with self.partition_condition_lock:
|
with self.partition_condition_lock:
|
||||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||||
@ -470,13 +503,13 @@ class WorkerBase(ABC):
|
|||||||
return self.module_partition._topo
|
return self.module_partition._topo
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def use_middleware(self):
|
def use_middleware(self):
|
||||||
topo = self.get_topo()
|
topo = self.get_topo()
|
||||||
return topo is not None
|
return topo is not None
|
||||||
|
|
||||||
# TODO(jiangziyue) get single value instead of the whole output
|
# TODO(jiangziyue) get single value instead of the whole output
|
||||||
def _get_real_args_kwargs(self, args_or_kwargs):
|
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
|
||||||
if not self.use_middleware():
|
if not self.use_middleware():
|
||||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||||
if args_or_kwargs is not None:
|
if args_or_kwargs is not None:
|
||||||
@ -491,8 +524,8 @@ class WorkerBase(ABC):
|
|||||||
if args_or_kwargs is not None:
|
if args_or_kwargs is not None:
|
||||||
if isinstance(args_or_kwargs, dict):
|
if isinstance(args_or_kwargs, dict):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
flatten_args = []
|
flatten_args = []
|
||||||
if self.is_first_stage():
|
if self.is_first_stage():
|
||||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||||
# TODO get by offset
|
# TODO get by offset
|
||||||
@ -525,7 +558,7 @@ class WorkerBase(ABC):
|
|||||||
if stage_id == src_stage_id:
|
if stage_id == src_stage_id:
|
||||||
src_index += i
|
src_index += i
|
||||||
break
|
break
|
||||||
else: # data from input partition
|
else: # data from input partition
|
||||||
src_index = 0
|
src_index = 0
|
||||||
# when output_len = 1, not iterable
|
# when output_len = 1, not iterable
|
||||||
if output_len == 1:
|
if output_len == 1:
|
||||||
@ -536,6 +569,55 @@ class WorkerBase(ABC):
|
|||||||
args_or_kwargs = flatten_args
|
args_or_kwargs = flatten_args
|
||||||
return args_or_kwargs
|
return args_or_kwargs
|
||||||
|
|
||||||
|
# TODO(jiangziyue) get single value instead of the whole output
|
||||||
|
def _get_real_args_kwargs_bwd(self, args_or_kwargs):
|
||||||
|
if not self.use_middleware():
|
||||||
|
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||||
|
if args_or_kwargs is not None:
|
||||||
|
if isinstance(args_or_kwargs, dict):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
flatten_args = []
|
||||||
|
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||||
|
args_or_kwargs = flatten_args
|
||||||
|
else:
|
||||||
|
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||||
|
if args_or_kwargs is not None:
|
||||||
|
flatten_args = []
|
||||||
|
# TODO get by offset
|
||||||
|
topo: Topo = self.get_topo()
|
||||||
|
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||||
|
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||||
|
output_vals = self_partition.get_output_vals()
|
||||||
|
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||||
|
for val_list in output_vals:
|
||||||
|
# An output may be passed to many down stages.
|
||||||
|
target = None
|
||||||
|
for val_pos in val_list.get():
|
||||||
|
dst_partition_id = val_pos.partition_id
|
||||||
|
dst_offset = val_pos.offset
|
||||||
|
dst_partition = topo.get_partition_by_id(dst_partition_id)
|
||||||
|
input_len = len(dst_partition.get_input_vals())
|
||||||
|
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
|
||||||
|
for i, stage_id in enumerate(consumer_stage_ids):
|
||||||
|
if stage_id == dst_stage_id:
|
||||||
|
dst_index = i
|
||||||
|
break
|
||||||
|
if input_len == 1:
|
||||||
|
part_grad = args_or_kwargs[dst_index]
|
||||||
|
else:
|
||||||
|
part_grad = args_or_kwargs[dst_index][dst_offset]
|
||||||
|
|
||||||
|
if target is None:
|
||||||
|
target = part_grad
|
||||||
|
elif part_grad is not None:
|
||||||
|
target += part_grad
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
flatten_args.append(target)
|
||||||
|
args_or_kwargs = flatten_args
|
||||||
|
return args_or_kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_work_item_key(self) -> UniqueKey:
|
def _get_work_item_key(self) -> UniqueKey:
|
||||||
"""
|
"""
|
||||||
@ -547,7 +629,7 @@ class WorkerBase(ABC):
|
|||||||
|
|
||||||
def is_last_stage(self):
|
def is_last_stage(self):
|
||||||
return self.pp_rank == self.actual_stage_num - 1
|
return self.pp_rank == self.actual_stage_num - 1
|
||||||
|
|
||||||
def need_model_input(self):
|
def need_model_input(self):
|
||||||
need_input = False
|
need_input = False
|
||||||
topo: Topo = self.get_topo()
|
topo: Topo = self.get_topo()
|
||||||
@ -558,10 +640,13 @@ class WorkerBase(ABC):
|
|||||||
if model_input_partition_id in partition_inputs:
|
if model_input_partition_id in partition_inputs:
|
||||||
need_input = True
|
need_input = True
|
||||||
return not self.is_first_stage() and need_input
|
return not self.is_first_stage() and need_input
|
||||||
|
|
||||||
def is_model_output(self):
|
def is_model_output(self):
|
||||||
return self.is_last_stage()
|
return self.is_last_stage()
|
||||||
|
|
||||||
|
def is_model_input(self):
|
||||||
|
return self.is_first_stage()
|
||||||
|
|
||||||
def _default_data_process_func(self, args_kwargs):
|
def _default_data_process_func(self, args_kwargs):
|
||||||
if self.is_first_stage():
|
if self.is_first_stage():
|
||||||
args = args_kwargs[0]
|
args = args_kwargs[0]
|
||||||
@ -598,11 +683,16 @@ class WorkerBase(ABC):
|
|||||||
|
|
||||||
# parse and integrate args and kwargs
|
# parse and integrate args and kwargs
|
||||||
if is_first_stage:
|
if is_first_stage:
|
||||||
args = self._get_real_args_kwargs(args)
|
args = self._get_real_args_kwargs_fwd(args)
|
||||||
kwargs = self._get_real_args_kwargs(kwargs)
|
kwargs = self._get_real_args_kwargs_fwd(kwargs)
|
||||||
args_kwargs = (args, kwargs)
|
args_kwargs = (args, kwargs)
|
||||||
else:
|
else:
|
||||||
args_kwargs = self._get_real_args_kwargs(args)
|
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||||
|
|
||||||
|
if not forward_only:
|
||||||
|
pytree_map(args_kwargs,
|
||||||
|
lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
||||||
|
process_types=torch.Tensor)
|
||||||
|
|
||||||
args, kwargs = data_process_func(args_kwargs)
|
args, kwargs = data_process_func(args_kwargs)
|
||||||
|
|
||||||
@ -694,21 +784,40 @@ class WorkerBase(ABC):
|
|||||||
|
|
||||||
# overlap recompute and future.wait
|
# overlap recompute and future.wait
|
||||||
if not is_last_stage:
|
if not is_last_stage:
|
||||||
grad_tensors = self._get_real_args_kwargs(args)
|
grad_tensors = self._get_real_args_kwargs_bwd(args)
|
||||||
else:
|
else:
|
||||||
grad_tensors = None
|
grad_tensors = None
|
||||||
|
|
||||||
# take tensor only (for only tensor can do backward)
|
# take tensor only (for only tensor can do backward)
|
||||||
stage_outputs = pytree_filter(lambda x: x.requires_grad, stage_outputs, process_types=torch.Tensor)
|
# TODO(jiangziyue) : All values which should do bp are torch.Tensor?
|
||||||
grad_tensors = pytree_filter(lambda x: x is not None, grad_tensors, process_types=torch.Tensor)
|
stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
|
||||||
|
grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
|
||||||
|
|
||||||
|
# output all input's grad to producer, even it has no grad(output None)
|
||||||
|
# to make the offset aligned to the topo's record.
|
||||||
|
if grad_tensors is not None:
|
||||||
|
filtered_outputs = []
|
||||||
|
filtered_grads = []
|
||||||
|
for i, grad in enumerate(grad_tensors):
|
||||||
|
stage_output = stage_outputs[i]
|
||||||
|
if stage_output.requires_grad and grad is not None:
|
||||||
|
filtered_outputs.append(stage_output)
|
||||||
|
filtered_grads.append(grad)
|
||||||
|
|
||||||
|
stage_outputs = filtered_outputs
|
||||||
|
grad_tensors = filtered_grads
|
||||||
|
|
||||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||||
|
|
||||||
# collect grad of input tensor
|
# collect grad of input tensor
|
||||||
consume_result = []
|
consume_result = []
|
||||||
if not is_first_stage:
|
if not is_first_stage:
|
||||||
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
# In current design, input mush be a flatten args.
|
||||||
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
for arg in stage_input_args:
|
||||||
|
if isinstance(arg, torch.Tensor):
|
||||||
|
consume_result.append(arg.grad)
|
||||||
|
else:
|
||||||
|
consume_result.append(None)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||||
@ -740,11 +849,11 @@ class WorkerBase(ABC):
|
|||||||
def _hook_before_step(self):
|
def _hook_before_step(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _reset_context(self):
|
# install the main loop to wait for next batch input
|
||||||
self.forward_times = 0
|
def _wait_for_reset(self):
|
||||||
self.backward_times = 0
|
with self.reset_condition:
|
||||||
self.outstanding = 0
|
self.reset_condition.wait_for(lambda: self.reset)
|
||||||
self._initialize_outstanding_range()
|
self.reset = False
|
||||||
|
|
||||||
# do the main loop to consume ready_list
|
# do the main loop to consume ready_list
|
||||||
def _work_loop(self):
|
def _work_loop(self):
|
||||||
@ -755,10 +864,9 @@ class WorkerBase(ABC):
|
|||||||
# main loop
|
# main loop
|
||||||
while True:
|
while True:
|
||||||
work_item_key = self._get_work_item_key()
|
work_item_key = self._get_work_item_key()
|
||||||
|
|
||||||
# move current work item to output_list to activate subscribe in advance
|
# move current work item to output_list to activate subscribe in advance
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
#self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
|
self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
|
||||||
work_item = self.work_list[work_item_key]
|
work_item = self.work_list[work_item_key]
|
||||||
|
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
@ -768,16 +876,32 @@ class WorkerBase(ABC):
|
|||||||
|
|
||||||
consume_result = self._consume_work_item_by_phase(work_item)
|
consume_result = self._consume_work_item_by_phase(work_item)
|
||||||
|
|
||||||
work_item.output.set_result(consume_result)
|
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
self.work_list.pop(work_item_key)
|
self.work_list.pop(work_item_key)
|
||||||
|
work_item.output.set_result(consume_result)
|
||||||
|
|
||||||
# if is last step in one batch reset context and do step
|
# if is last step in one batch reset context and do step
|
||||||
if self._is_last_step(work_item):
|
if self._is_last_step(work_item):
|
||||||
self._hook_before_step()
|
self._hook_before_step()
|
||||||
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
||||||
self.step()
|
self.step()
|
||||||
self._reset_context()
|
self._wait_for_reset()
|
||||||
|
|
||||||
|
# reset context and resume loop
|
||||||
|
def reset_context(self):
|
||||||
|
self.forward_times = 0
|
||||||
|
self.backward_times = 0
|
||||||
|
self.outstanding = 0
|
||||||
|
self._initialize_outstanding_range()
|
||||||
|
with self.work_list_condition_lock:
|
||||||
|
self.work_list.clear()
|
||||||
|
|
||||||
|
with self.output_list_condition_lock:
|
||||||
|
self.output_list.clear()
|
||||||
|
|
||||||
|
with self.reset_condition:
|
||||||
|
self.reset = True
|
||||||
|
self.reset_condition.notify_all()
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||||
@ -856,7 +980,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||||
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
||||||
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
|
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
|
||||||
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
||||||
of partitions will be moved to device 0 and the others to device 1
|
of partitions will be moved to device 0 and the others to device 1
|
||||||
"""
|
"""
|
||||||
@ -947,7 +1071,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||||
for pp_rank in input_pp_ranks:
|
for pp_rank in input_pp_ranks:
|
||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
worker_rref.rpc_sync().get_output_by_key(key)
|
worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
|
||||||
|
|
||||||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
||||||
num_microbatches = self.num_microbatches
|
num_microbatches = self.num_microbatches
|
||||||
@ -965,6 +1089,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
# TODO : add relationship between output_pp_ranks and parts of microlabels
|
# TODO : add relationship between output_pp_ranks and parts of microlabels
|
||||||
worker_rref.remote().set_labels(microbatch_id, microlabels)
|
worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||||
|
|
||||||
|
# TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
|
||||||
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||||
for pp_rank in output_pp_ranks:
|
for pp_rank in output_pp_ranks:
|
||||||
@ -993,6 +1118,16 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
return forward_result
|
return forward_result
|
||||||
|
|
||||||
|
def _reset_worker(self):
|
||||||
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
|
for pp_rank in range(actual_stage_num):
|
||||||
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
|
fut = worker_rref.rpc_async().reset_context()
|
||||||
|
self.step_futs.append(fut)
|
||||||
|
|
||||||
|
for fut in self.step_futs:
|
||||||
|
fut.wait()
|
||||||
|
|
||||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||||
batch_lengths = get_batch_lengths(batch)
|
batch_lengths = get_batch_lengths(batch)
|
||||||
batch_length = batch_lengths[0]
|
batch_length = batch_lengths[0]
|
||||||
@ -1046,6 +1181,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
worker_rref.rpc_sync().wait_for_step()
|
worker_rref.rpc_sync().wait_for_step()
|
||||||
|
|
||||||
|
self._reset_worker() # reset worker attributes for next batch
|
||||||
return forward_result
|
return forward_result
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
|
@ -89,9 +89,6 @@ class OneFOneBWorker(WorkerBase):
|
|||||||
elif target_key.microbatch_id == num_microbatches - 1:
|
elif target_key.microbatch_id == num_microbatches - 1:
|
||||||
self.outstanding_range = (0, 0)
|
self.outstanding_range = (0, 0)
|
||||||
|
|
||||||
with self.work_list_condition_lock:
|
|
||||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
|
||||||
|
|
||||||
return target_key
|
return target_key
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,7 +57,6 @@ def split_batch(batch: Any, start, stop, device: str):
|
|||||||
def type_detail(obj):
|
def type_detail(obj):
|
||||||
return pytree_map(obj, lambda x: type(x), map_all=True)
|
return pytree_map(obj, lambda x: type(x), map_all=True)
|
||||||
|
|
||||||
|
|
||||||
def pytree_filter(fn, obj, process_types):
|
def pytree_filter(fn, obj, process_types):
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return None
|
return None
|
||||||
|
@ -31,7 +31,7 @@ class MLP(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x.sum()
|
||||||
|
|
||||||
class DAG_MLP(nn.Module):
|
class DAG_MLP(nn.Module):
|
||||||
def __init__(self, dim: int, layers: int):
|
def __init__(self, dim: int, layers: int):
|
||||||
@ -46,7 +46,7 @@ class DAG_MLP(nn.Module):
|
|||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
y = self.dag_layer(y)
|
y = self.dag_layer(y)
|
||||||
return x, y
|
return x.sum(), y.sum()
|
||||||
|
|
||||||
class RpcTestModel(nn.Module):
|
class RpcTestModel(nn.Module):
|
||||||
|
|
||||||
|
@ -41,10 +41,10 @@ def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int
|
|||||||
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||||
return partition
|
return partition
|
||||||
|
|
||||||
def run_master(model_cls, world_size):
|
def run_master(model_cls, world_size, forward_only):
|
||||||
torch.manual_seed(100)
|
torch.manual_seed(100)
|
||||||
|
|
||||||
epoch = 10
|
epoch = 3
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
stage_num = world_size
|
stage_num = world_size
|
||||||
chunk = 1
|
chunk = 1
|
||||||
@ -57,6 +57,10 @@ def run_master(model_cls, world_size):
|
|||||||
kwargs = dict(x=x)
|
kwargs = dict(x=x)
|
||||||
return kwargs
|
return kwargs
|
||||||
model = model_cls(dim, stage_num * 3)
|
model = model_cls(dim, stage_num * 3)
|
||||||
|
if forward_only:
|
||||||
|
labels = None
|
||||||
|
else:
|
||||||
|
labels = 1
|
||||||
elif model_cls == DAG_MLP:
|
elif model_cls == DAG_MLP:
|
||||||
def data_gen():
|
def data_gen():
|
||||||
x = torch.zeros((batch_size, dim))
|
x = torch.zeros((batch_size, dim))
|
||||||
@ -64,24 +68,30 @@ def run_master(model_cls, world_size):
|
|||||||
kwargs = dict(x=x, y=y)
|
kwargs = dict(x=x, y=y)
|
||||||
return kwargs
|
return kwargs
|
||||||
model = model_cls(dim, stage_num * 3)
|
model = model_cls(dim, stage_num * 3)
|
||||||
|
if forward_only:
|
||||||
|
labels = None
|
||||||
|
else:
|
||||||
|
labels = 1
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
data_kwargs = data_gen()
|
data_kwargs = data_gen()
|
||||||
|
|
||||||
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
|
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
chunk=chunk,
|
chunk=chunk,
|
||||||
checkpoint=use_checkpoint,)
|
checkpoint=use_checkpoint,)
|
||||||
|
if not forward_only:
|
||||||
|
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
|
||||||
|
|
||||||
for _ in range(epoch):
|
for _ in range(epoch):
|
||||||
input_x = torch.randn((batch_size, dim), device=device)
|
input_x = torch.randn((batch_size, dim), device=device)
|
||||||
input_y = torch.randn((batch_size, dim), device=device)
|
input_y = torch.randn((batch_size, dim), device=device)
|
||||||
logits = engine.forward_backward({'x': input_x, 'y': input_y}, forward_only=True)
|
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
|
||||||
|
|
||||||
def run_worker(rank, model_cls, world_size, master_func):
|
def run_worker(rank, model_cls, world_size, forward_only, master_func):
|
||||||
master_addr = 'localhost'
|
master_addr = 'localhost'
|
||||||
master_port = 29020
|
master_port = 29020
|
||||||
os.environ['MASTER_ADDR'] = master_addr
|
os.environ['MASTER_ADDR'] = master_addr
|
||||||
@ -99,19 +109,20 @@ def run_worker(rank, model_cls, world_size, master_func):
|
|||||||
|
|
||||||
# in rpc mode, only rank 0 is needed to be coded
|
# in rpc mode, only rank 0 is needed to be coded
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
master_func(model_cls, world_size)
|
master_func(model_cls, world_size, forward_only)
|
||||||
# barrier here
|
# barrier here
|
||||||
if rpc_is_initialized():
|
if rpc_is_initialized():
|
||||||
rpc.shutdown()
|
rpc.shutdown()
|
||||||
|
|
||||||
@pytest.mark.skip("skip due to CI torch version 1.11")
|
@pytest.mark.skip("skip due to CI torch version 1.11")
|
||||||
@parameterize('model_cls', [MLP, DAG_MLP])
|
@parameterize('model_cls', [MLP, DAG_MLP])
|
||||||
|
@parameterize('forward_only', [True, False])
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_pp_middleware_fwd(model_cls):
|
def test_pp_middleware_fwd(model_cls, forward_only):
|
||||||
world_size = 4
|
world_size = 4
|
||||||
master_func = run_master
|
master_func = run_master
|
||||||
mp.spawn(run_worker, args=(model_cls, world_size, master_func), nprocs=world_size)
|
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_pp_middleware_fwd()
|
test_pp_middleware_fwd()
|
Loading…
Reference in New Issue
Block a user