[rpc] split with dag (#2028)

* add DAG to split_module

* add comment

* add test case for DAG

* remove print

* add DAG middleware in scheduler

* add test case for scheduler

* remove break

* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang 2022-11-29 11:36:28 +08:00 committed by GitHub
parent 96134e7be3
commit b0936e4a44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 337 additions and 53 deletions

View File

@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
# TODO(lyl): use partition IR to assign partition ID to each node.
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
part_idx += 1
return part_idx
split_mod = split_module(annotated_gm, None, split_callback)
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):

View File

@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
for partition in partitions:
if node == partition:
user_partition_names.append(partition.name)
# find user with getitem call
else:
for partition in partitions:
if node in partition.args:
user_partition_names.append(partition.name)
is_output = False
def find_output(def_node, output_node):
nonlocal is_output
if def_node == output_node:
is_output = True
if output_partitions is not None:
output_node = output_partitions[0]
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
if is_output:
user_partition_names.append('MODEL_OUTPUT')
if node.op == output_node.op:
user_partition_names.append('MODEL_OUTPUT')
if len(user_partition_names) > 0:
return user_partition_names

View File

@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
@ -20,7 +20,7 @@ class Phase(Enum):
FORWARD = 0
BACKWARD = 1
UPDATE = 2
INPUT = 3
class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
@ -128,6 +128,7 @@ class WorkerBase(ABC):
# topology info
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
self.input_consumer_stage_ids: List[int] = None
# module partitions
self.partition_fn = partition_fn
@ -135,6 +136,11 @@ class WorkerBase(ABC):
self.criterion = criterion
self.metric = metric
# middleware info
self._is_input = False
self._is_output = False
self._producer_consumer_initialized = False
# context to maintain loop
self._initialize_context_container()
@ -164,6 +170,7 @@ class WorkerBase(ABC):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.producer_consumer_init_lock = threading.Condition(threading.Lock())
def _initialize_partition(self):
partition_fn = self.partition_fn
@ -182,7 +189,7 @@ class WorkerBase(ABC):
# construction of partition is executed after the registion of pp_rank_to_worker_rref
self._initialize_partition()
def get_output_by_key(self, key: UniqueKey) -> Any:
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
@ -191,8 +198,9 @@ class WorkerBase(ABC):
if isinstance(output, Future):
output = output.wait()
output_work_item.refcount += 1
# output_work_item.refcount += 1
# TODO(jiangziyue) redesign lifecycle management for DAG scheduler
# all consumers have been satisfied, the work_item can be released
with self.output_list_condition_lock:
if output_work_item.refcount >= len(self.consumer_stage_ids):
@ -215,8 +223,10 @@ class WorkerBase(ABC):
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch):
def _make_args_kwargs(self, microbatch, merge=False):
if isinstance(microbatch, dict):
if merge:
return list(microbatch.values()), {}
return [], microbatch
elif isinstance(microbatch, torch.Tensor):
return [microbatch], {}
@ -228,24 +238,70 @@ class WorkerBase(ABC):
kwargs.update(arg)
else:
args.append(arg)
if merge:
arg_lst = args
for arg in kwargs.values():
arg_lst.append(arg)
return arg_lst, {}
return args, kwargs
else:
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
# just for first pp_rank
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
# TODO(jiangziyue) Define a Class for DAG.
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()
if not self.use_middleware():
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
else:
# make args and kwargs
arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# first stage assign correct input into other stages
DAG = self.get_DAG()
DAG_node = DAG['input_partition']
self_input_offsets = []
recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
# notify rank which should receive extra input
offset = 0
for details in DAG_node.values():
for partition_name in details['output'].keys():
recv_rank = self.partition_name_to_pp_rank(partition_name)
if recv_rank == self.pp_rank:
self_input_offsets.append(offset)
elif recv_rank not in self.input_consumer_stage_ids:
self.input_consumer_stage_ids.append(recv_rank)
offset += 1
# set input for self rank
self_arg_lst = []
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
self.output_list_condition_lock.notify_all()
# just for last pp_rank
def set_labels(self, microbatch_id: int, microlabels: Any):
@ -268,33 +324,68 @@ class WorkerBase(ABC):
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
if not self.use_middleware():
producer_num = len(self.producer_stage_ids)
subscribe_forward_futures: List[Future] = [None] * producer_num
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
else:
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
if key in self.work_list:
return
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
producer_stage_ids = []
with self.producer_consumer_init_lock:
self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
producer_stage_ids = self.producer_stage_ids
producer_num = len(producer_stage_ids)
# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_num += 1 # extra one(the last one) for input_tensor
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
for i in range(0, producer_num-1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
#producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
microbatch_id, None, self.num_microbatches, forward_only)
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
if key not in self.work_list:
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
"""
@ -334,13 +425,132 @@ class WorkerBase(ABC):
self.producer_stage_ids = []
self.consumer_stage_ids = []
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
if not self.use_middleware():
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
else:
self.input_consumer_stage_ids = []
DAG = self.get_DAG()
DAG_node_name = self.pp_rank_to_partition_name(rank)
DAG_node = DAG[DAG_node_name]
for partition_name in DAG_node['input'].keys():
if partition_name == 'MODEL_INPUT':
self._is_input = True
else:
prev_rank = self.partition_name_to_pp_rank(partition_name)
self.producer_stage_ids.append(prev_rank)
for partition_name in DAG_node['output'].keys():
if partition_name == 'MODEL_OUTPUT':
self._is_output = True
else:
next_rank = self.partition_name_to_pp_rank(partition_name)
self.consumer_stage_ids.append(next_rank)
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
with self.producer_consumer_init_lock:
self._producer_consumer_initialized = True
self.producer_consumer_init_lock.notify_all()
# TODO(jiangziyue) Define a Class for DAG.
def pp_rank_to_partition_name(self, pp_rank: int):
prefix = 'submod_'
partition_name = prefix + str(pp_rank)
return partition_name
# TODO(jiangziyue) Define a Class for DAG.
def partition_name_to_pp_rank(self, partition_name: str) -> int:
prefix = 'submod_'
pp_rank = int(partition_name.split(prefix)[-1])
return pp_rank
def get_DAG(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_DAG'):
return self.module_partition._DAG
else:
return None
def use_middleware(self):
DAG = self.get_DAG()
return DAG is not None
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
# TODO get by offset
else:
DAG = self.get_DAG()
producer_outputs = {}
cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
#cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
for i, args_from_one_mod in enumerate(args_or_kwargs):
producer_output_offsets = []
if self.need_model_input():
if i == 0:
producer_DAG_node = DAG['input_partition']
producer_partition_name = 'MODEL_INPUT'
offset = 0
for arg_info in producer_DAG_node.values():
if cur_DAG_node_name in arg_info['output']:
producer_output_offsets.append(offset)
offset += 1
else:
producer_rank = self.producer_stage_ids[i-1]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
else:
producer_rank = self.producer_stage_ids[i]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
producer_outputs[producer_partition_name] = [args_from_one_mod]
else:
producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
def get_input_len(DAG_node_input):
res = 0
for offsets in DAG_node_input.values():
res += len(offsets)
return res
input_len = get_input_len(cur_DAG_node_input)
flatten_args = [None] * input_len
for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
for i, args_input_offset in enumerate(args_input_offsets):
flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
args_or_kwargs = flatten_args
return args_or_kwargs
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
@ -353,6 +563,9 @@ class WorkerBase(ABC):
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def need_model_input(self):
return not self.is_first_stage() and self._is_input
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
@ -390,11 +603,11 @@ class WorkerBase(ABC):
# parse and integrate args and kwargs
if is_first_stage:
args = get_real_args_kwargs(args)
kwargs = get_real_args_kwargs(kwargs)
args = self._get_real_args_kwargs(args)
kwargs = self._get_real_args_kwargs(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = get_real_args_kwargs(args)
args_kwargs = self._get_real_args_kwargs(args)
args, kwargs = data_process_func(args_kwargs)
@ -486,7 +699,7 @@ class WorkerBase(ABC):
# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
grad_tensors = self._get_real_args_kwargs(args)
else:
grad_tensors = None
@ -569,7 +782,10 @@ class WorkerBase(ABC):
self._reset_context()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
@ -577,8 +793,11 @@ class WorkerBase(ABC):
self.step_lock.acquire()
def step(self):
self.optimizer.step()
self.optimizer.zero_grad()
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.step_lock.release()

View File

@ -20,6 +20,18 @@ def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
class MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
for _ in range(layers):
self.layers.append(nn.Linear(dim, dim, bias=False))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class RpcTestModel(nn.Module):

View File

@ -0,0 +1,60 @@
import torch
from torch import nn
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from rpc_test_utils import rpc_run, parse_args, MLP
from functools import partial
# global variable for model created
batch_size = 16
dim = 10
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
model.eval()
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_model = balanced_split_pass(gm, stage_num)
split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
return list(split_model.children())[pp_rank]
def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
torch.manual_seed(1024)
model = MLP(dim, stage_num * 3)
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return partition
def run_master(args):
torch.manual_seed(100)
epoch = args.epoch
device = args.device
stage_num = args.world_size
chunk = args.chunk
num_microbatches = args.num_microbatches
use_checkpoint = args.use_checkpoint
input_sample = torch.randn((batch_size, dim), device=device)
def data_gen():
x = torch.zeros((batch_size, dim))
kwargs = dict(x=x)
return kwargs
data_kwargs = data_gen()
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
checkpoint=use_checkpoint)
for _ in range(epoch):
logits = engine.forward_backward({'x': input_sample}, forward_only=True)
if __name__ == "__main__":
args = parse_args()
rpc_run(args, run_master)