[rpc] split with dag (#2028)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

* add DAG middleware in scheduler

* add test case for scheduler

* remove break

* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Authored by Ziyue Jiang on 2022-11-29 11:36:28 +08:00, committed by GitHub
parent 96134e7be3
commit b0936e4a44
5 changed files with 337 additions and 53 deletions

View File

@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
# TODO(lyl): use partition IR to assign partition ID to each node.
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
part_idx += 1
return part_idx
split_mod = split_module(annotated_gm, None, split_callback)
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
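
A minimal driver sketch of the updated pass, assuming the functions are importable from colossalai.fx.passes.adding_split_node_pass and that split_with_split_nodes_pass returns the split GraphModule together with its submodules (as the split_submodules bookkeeping above suggests). Where the resulting DAG ends up attached (top module vs. each partition) is treated as an assumption here; the scheduler further down only reads it off its own partition via get_DAG().

import torch
from torch import nn
from torch.fx import symbolic_trace

# Assumed import path for the pass shown in this diff.
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)

class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

annotated_gm = symbolic_trace(ToyModel())
annotated_gm = uniform_split_pass(annotated_gm, pp_size=2)

# New in this commit: merge_output is forwarded to torch.fx's split_module.
split_gm, split_submodules = split_with_split_nodes_pass(annotated_gm, merge_output=True)

# Inspect where the DAG metadata ended up (its placement is an assumption here).
for name, mod in split_gm.named_modules():
    if isinstance(mod, torch.fx.GraphModule):
        print(name or '<root>', hasattr(mod, '_DAG'))
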

View File

@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
for partition in partitions:
if node == partition:
user_partition_names.append(partition.name)
# find user with getitem call
else:
for partition in partitions:
if node in partition.args:
user_partition_names.append(partition.name)
is_output = False
def find_output(def_node, output_node):
nonlocal is_output
if def_node == output_node:
is_output = True
if output_partitions is not None:
output_node = output_partitions[0]
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
if is_output:
user_partition_names.append('MODEL_OUTPUT')
if node.op == output_node.op:
user_partition_names.append('MODEL_OUTPUT')
if len(user_partition_names) > 0:
return user_partition_names
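
This hunk lets find_user_in_partition also report the model's own output ('MODEL_OUTPUT') as a consumer. The overall layout of the DAG the pass records is not shown in this file, but from the way the scheduler below indexes it (DAG['input_partition'], per-partition 'input'/'output'/'output_len', and the sentinels 'MODEL_INPUT'/'MODEL_OUTPUT'), it plausibly looks like the following hand-written two-stage example; the concrete names and offsets are illustrative only.

# Hand-written illustration of the DAG the scheduler consumes below; not the
# output of running the pass.
example_dag = {
    # One entry per model-level input; 'output' says which partitions consume it
    # and at which argument positions.
    'input_partition': {
        'x': {'output': {'submod_0': [0]}},
        'y': {'output': {'submod_1': [1]}},
    },
    'submod_0': {
        'input': {'MODEL_INPUT': [0]},      # arg 0 comes straight from the model input
        'output': {'submod_1': [0]},        # its output slot 0 feeds submod_1
        'output_len': 1,
    },
    'submod_1': {
        'input': {'submod_0': [0], 'MODEL_INPUT': [1]},   # arg 0 from submod_0, arg 1 from raw input y
        'output': {'MODEL_OUTPUT': [0]},    # its output is the model output
        'output_len': 1,
    },
}
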

View File

@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
@@ -20,7 +20,7 @@ class Phase(Enum):
FORWARD = 0
BACKWARD = 1
UPDATE = 2
INPUT = 3
class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
@@ -128,6 +128,7 @@ class WorkerBase(ABC):
# topology info
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
self.input_consumer_stage_ids: List[int] = None
# module partitions
self.partition_fn = partition_fn
@@ -135,6 +136,11 @@ class WorkerBase(ABC):
self.criterion = criterion
self.metric = metric
# middleware info
self._is_input = False
self._is_output = False
self._producer_consumer_initialized = False
# context to maintain loop
self._initialize_context_container()
@@ -164,6 +170,7 @@ class WorkerBase(ABC):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.producer_consumer_init_lock = threading.Condition(threading.Lock())
def _initialize_partition(self):
partition_fn = self.partition_fn
@@ -182,7 +189,7 @@ class WorkerBase(ABC):
# construction of partition is executed after the registration of pp_rank_to_worker_rref
self._initialize_partition()
def get_output_by_key(self, key: UniqueKey) -> Any:
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
@@ -191,8 +198,9 @@ class WorkerBase(ABC):
if isinstance(output, Future):
output = output.wait()
output_work_item.refcount += 1
# output_work_item.refcount += 1
# TODO(jiangziyue) redesign lifecycle management for DAG scheduler
# all consumers have been satisfied, the work_item can be released
with self.output_list_condition_lock:
if output_work_item.refcount >= len(self.consumer_stage_ids):
@@ -215,8 +223,10 @@ class WorkerBase(ABC):
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch):
def _make_args_kwargs(self, microbatch, merge=False):
if isinstance(microbatch, dict):
if merge:
return list(microbatch.values()), {}
return [], microbatch
elif isinstance(microbatch, torch.Tensor):
return [microbatch], {}
@@ -228,24 +238,70 @@ class WorkerBase(ABC):
kwargs.update(arg)
else:
args.append(arg)
if merge:
arg_lst = args
for arg in kwargs.values():
arg_lst.append(arg)
return arg_lst, {}
return args, kwargs
else:
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
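
For clarity, a self-contained mirror of _make_args_kwargs as changed here (illustrative only; the real method lives on WorkerBase), showing what merge=True does to each accepted microbatch shape: everything is flattened into one positional list so the first stage can index model inputs by offset against the DAG.

import torch

def make_args_kwargs(microbatch, merge=False):
    # Standalone mirror of WorkerBase._make_args_kwargs, for illustration.
    if isinstance(microbatch, dict):
        return (list(microbatch.values()), {}) if merge else ([], microbatch)
    elif isinstance(microbatch, torch.Tensor):
        return [microbatch], {}
    elif isinstance(microbatch, (tuple, list)):
        args, kwargs = [], {}
        for arg in microbatch:
            if isinstance(arg, dict):
                kwargs.update(arg)
            else:
                args.append(arg)
        return (args + list(kwargs.values()), {}) if merge else (args, kwargs)
    raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")

x, mask = torch.randn(2, 4), torch.ones(2, 4)
print(make_args_kwargs({'x': x, 'mask': mask}))              # ([], {'x': ..., 'mask': ...})
print(make_args_kwargs({'x': x, 'mask': mask}, merge=True))  # ([x, mask], {})
print(make_args_kwargs((x, {'mask': mask}), merge=True))     # ([x, mask], {})
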
# just for first pp_rank
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
# TODO(jiangziyue) Define a Class for DAG.
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()
if not self.use_middleware():
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
else:
# make args and kwargs
arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# first stage assigns the correct inputs to other stages
DAG = self.get_DAG()
DAG_node = DAG['input_partition']
self_input_offsets = []
recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
# notify rank which should receive extra input
offset = 0
for details in DAG_node.values():
for partition_name in details['output'].keys():
recv_rank = self.partition_name_to_pp_rank(partition_name)
if recv_rank == self.pp_rank:
self_input_offsets.append(offset)
elif recv_rank not in self.input_consumer_stage_ids:
self.input_consumer_stage_ids.append(recv_rank)
offset += 1
# set input for self rank
self_arg_lst = []
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
self.output_list_condition_lock.notify_all()
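
Walking the hand-written example DAG above through this branch on rank 0 with microbatch (x, y): the loop over DAG['input_partition'] yields offset 0 (consumed by submod_0, i.e. this rank) and offset 1 (consumed by submod_1), so the local FORWARD work item receives [x] while the full [x, y] list is parked under UniqueKey(microbatch_id, Phase.INPUT) for stage 1 to fetch. A small standalone re-derivation of that bookkeeping (assuming one offset per model-level input, which is what the consumer-side indexing requires):

def plan_model_inputs(input_partition, self_rank):
    # Illustrative mirror of the offset bookkeeping above, not the worker's code.
    name_to_rank = lambda name: int(name.split('submod_')[-1])
    self_offsets, remote_ranks = [], []
    for offset, details in enumerate(input_partition.values()):
        for partition_name in details['output']:
            rank = name_to_rank(partition_name)
            if rank == self_rank:
                self_offsets.append(offset)
            elif rank not in remote_ranks:
                remote_ranks.append(rank)
    return self_offsets, remote_ranks

input_partition = {    # same shape as in the example DAG above
    'x': {'output': {'submod_0': [0]}},
    'y': {'output': {'submod_1': [1]}},
}
print(plan_model_inputs(input_partition, self_rank=0))
# ([0], [1]): keep model input 0 locally, publish the full list for stage 1
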
# just for last pp_rank
def set_labels(self, microbatch_id: int, microlabels: Any):
@@ -268,33 +324,68 @@ class WorkerBase(ABC):
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
if not self.use_middleware():
producer_num = len(self.producer_stage_ids)
subscribe_forward_futures: List[Future] = [None] * producer_num
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
else:
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
if key in self.work_list:
return
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
producer_stage_ids = []
with self.producer_consumer_init_lock:
self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
producer_stage_ids = self.producer_stage_ids
producer_num = len(producer_stage_ids)
# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_num += 1 # extra one(the last one) for input_tensor
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
for i in range(0, producer_num-1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
#producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
microbatch_id, None, self.num_microbatches, forward_only)
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
if key not in self.work_list:
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
"""
@@ -334,13 +425,132 @@ class WorkerBase(ABC):
self.producer_stage_ids = []
self.consumer_stage_ids = []
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
if not self.use_middleware():
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
else:
self.input_consumer_stage_ids = []
DAG = self.get_DAG()
DAG_node_name = self.pp_rank_to_partition_name(rank)
DAG_node = DAG[DAG_node_name]
for partition_name in DAG_node['input'].keys():
if partition_name == 'MODEL_INPUT':
self._is_input = True
else:
prev_rank = self.partition_name_to_pp_rank(partition_name)
self.producer_stage_ids.append(prev_rank)
for partition_name in DAG_node['output'].keys():
if partition_name == 'MODEL_OUTPUT':
self._is_output = True
else:
next_rank = self.partition_name_to_pp_rank(partition_name)
self.consumer_stage_ids.append(next_rank)
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
with self.producer_consumer_init_lock:
self._producer_consumer_initialized = True
self.producer_consumer_init_lock.notify_all()
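
Worked through against the example DAG above, this initialization gives the worker for submod_0: producers [], consumers [1], _is_input True, _is_output False; and for submod_1: producers [0], consumers [], both flags True (so need_model_input() holds). A compact re-derivation, illustrative only:

def derive_topology(dag, rank):
    # Mirrors the DAG branch above for a single rank; not the worker's code.
    name_to_rank = lambda name: int(name.split('submod_')[-1])
    node = dag[f'submod_{rank}']
    producers = [name_to_rank(n) for n in node['input'] if n != 'MODEL_INPUT']
    consumers = [name_to_rank(n) for n in node['output'] if n != 'MODEL_OUTPUT']
    return 'MODEL_INPUT' in node['input'], 'MODEL_OUTPUT' in node['output'], producers, consumers

dag = {
    'submod_0': {'input': {'MODEL_INPUT': [0]}, 'output': {'submod_1': [0]}, 'output_len': 1},
    'submod_1': {'input': {'submod_0': [0], 'MODEL_INPUT': [1]}, 'output': {'MODEL_OUTPUT': [0]}, 'output_len': 1},
}
print(derive_topology(dag, 0))   # (True, False, [], [1])
print(derive_topology(dag, 1))   # (True, True, [0], [])
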
# TODO(jiangziyue) Define a Class for DAG.
def pp_rank_to_partition_name(self, pp_rank: int):
prefix = 'submod_'
partition_name = prefix + str(pp_rank)
return partition_name
# TODO(jiangziyue) Define a Class for DAG.
def partition_name_to_pp_rank(self, partition_name: str) -> int:
prefix = 'submod_'
pp_rank = int(partition_name.split(prefix)[-1])
return pp_rank
def get_DAG(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_DAG'):
return self.module_partition._DAG
else:
return None
def use_middleware(self):
DAG = self.get_DAG()
return DAG is not None
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
# TODO get by offset
else:
DAG = self.get_DAG()
producer_outputs = {}
cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
#cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
for i, args_from_one_mod in enumerate(args_or_kwargs):
producer_output_offsets = []
if self.need_model_input():
if i == 0:
producer_DAG_node = DAG['input_partition']
producer_partition_name = 'MODEL_INPUT'
offset = 0
for arg_info in producer_DAG_node.values():
if cur_DAG_node_name in arg_info['output']:
producer_output_offsets.append(offset)
offset += 1
else:
producer_rank = self.producer_stage_ids[i-1]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
else:
producer_rank = self.producer_stage_ids[i]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
producer_outputs[producer_partition_name] = [args_from_one_mod]
else:
producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
def get_input_len(DAG_node_input):
res = 0
for offsets in DAG_node_input.values():
res += len(offsets)
return res
input_len = get_input_len(cur_DAG_node_input)
flatten_args = [None] * input_len
for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
for i, args_input_offset in enumerate(args_input_offsets):
flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
args_or_kwargs = flatten_args
return args_or_kwargs
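
A hand-run of the DAG branch for the worker holding submod_1 in the example DAG (illustrative re-implementation with plain values in place of resolved futures): because need_model_input() is true, index 0 of the producer results carries the raw model inputs published under Phase.INPUT and index 1 carries submod_0's output, and the 'input' offsets decide where each value lands in the flattened argument list.

import torch

dag = {
    'input_partition': {
        'x': {'output': {'submod_0': [0]}},
        'y': {'output': {'submod_1': [1]}},
    },
    'submod_0': {'input': {'MODEL_INPUT': [0]}, 'output': {'submod_1': [0]}, 'output_len': 1},
    'submod_1': {'input': {'submod_0': [0], 'MODEL_INPUT': [1]}, 'output': {'MODEL_OUTPUT': [0]}, 'output_len': 1},
}

x, y = torch.randn(2), torch.randn(2)
hidden = torch.randn(2)                   # stands in for submod_0's output
args_from_producers = [[x, y], hidden]    # index 0: Phase.INPUT payload, index 1: submod_0

producer_outputs = {}
# Index 0: keep only the model inputs that feed submod_1 (one offset per model-level input).
offsets = [i for i, info in enumerate(dag['input_partition'].values()) if 'submod_1' in info['output']]
producer_outputs['MODEL_INPUT'] = [args_from_producers[0][o] for o in offsets]
# Index 1: submod_0 has output_len == 1, so its single output is wrapped as-is.
producer_outputs['submod_0'] = [args_from_producers[1]]

# Scatter producer outputs into submod_1's argument positions.
inputs = dag['submod_1']['input']
flatten_args = [None] * sum(len(v) for v in inputs.values())
for producer, positions in inputs.items():
    for i, pos in enumerate(positions):
        flatten_args[pos] = producer_outputs[producer][i]

assert flatten_args[0] is hidden and flatten_args[1] is y
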
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
@@ -353,6 +563,9 @@ class WorkerBase(ABC):
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def need_model_input(self):
return not self.is_first_stage() and self._is_input
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
@@ -390,11 +603,11 @@ class WorkerBase(ABC):
# parse and integrate args and kwargs
if is_first_stage:
args = get_real_args_kwargs(args)
kwargs = get_real_args_kwargs(kwargs)
args = self._get_real_args_kwargs(args)
kwargs = self._get_real_args_kwargs(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = get_real_args_kwargs(args)
args_kwargs = self._get_real_args_kwargs(args)
args, kwargs = data_process_func(args_kwargs)
@@ -486,7 +699,7 @@ class WorkerBase(ABC):
# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
grad_tensors = self._get_real_args_kwargs(args)
else:
grad_tensors = None
@@ -569,7 +782,10 @@ class WorkerBase(ABC):
self._reset_context()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
@@ -577,8 +793,11 @@ class WorkerBase(ABC):
self.step_lock.acquire()
def step(self):
self.optimizer.step()
self.optimizer.zero_grad()
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.step_lock.release()
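
The temporary guards above exist because an FX split can currently yield a partition with no trainable parameters, and constructing a torch optimizer over an empty parameter list raises. A minimal illustration with a toy parameter-free module (not from this repository):

import torch
from torch import nn, optim

class ReshapeOnly(nn.Module):
    # A partition like this can fall out of a split: it owns no parameters.

    def forward(self, x):
        return torch.flatten(x, start_dim=1)

partition = ReshapeOnly()
params = list(partition.parameters())
print(len(params))    # 0

# optim.SGD(params, lr=0.1) would raise "optimizer got an empty parameter list",
# hence the guard around initialize_optimizer() and step().
if len(params) > 0:
    opt = optim.SGD(params, lr=0.1)
    opt.step()
    opt.zero_grad()
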