diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index a6911011e..503397878 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     return gm


-def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
+def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
     # TODO(lyl): use partition IR to assign partition ID to each node.
     # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
     # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
         part_idx += 1
         return part_idx

-    split_mod = split_module(annotated_gm, None, split_callback)
+    split_mod = split_module(annotated_gm, None, split_callback, merge_output)
     split_submodules = []
     for name, submodule in split_mod.named_modules():
         if isinstance(submodule, torch.fx.GraphModule):
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
index b4d3d2086..fda010fd3 100644
--- a/colossalai/fx/passes/utils.py
+++ b/colossalai/fx/passes/utils.py
@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
         for partition in partitions:
             if node == partition:
                 user_partition_names.append(partition.name)
+    # find user with getitem call
     else:
         for partition in partitions:
             if node in partition.args:
                 user_partition_names.append(partition.name)

-    is_output = False
-
-    def find_output(def_node, output_node):
-        nonlocal is_output
-        if def_node == output_node:
-            is_output = True
-
+
     if output_partitions is not None:
         output_node = output_partitions[0]
-        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
-
-    if is_output:
-        user_partition_names.append('MODEL_OUTPUT')
+        if node.op == output_node.op:
+            user_partition_names.append('MODEL_OUTPUT')

     if len(user_partition_names) > 0:
         return user_partition_names
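Reviewer note (not part of the patch): the new `merge_output` flag is simply forwarded as the fourth positional argument of the `split_module` pass used by `split_with_split_nodes_pass`; judging by the name, it asks that pass to merge the split graph's outputs into a single output node (the pass itself is not shown in this diff). A minimal usage sketch, mirroring the new test added at the bottom of this patch — the toy `MLP`, sizes, and printed names are illustrative only:

```python
import torch
from torch import nn

from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass


class MLP(nn.Module):
    # toy model, for illustration only
    def __init__(self, dim: int, layers: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim, bias=False) for _ in range(layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = MLP(dim=10, layers=6)
graph = ColoTracer().trace(root=model, meta_args={'x': torch.zeros(16, 10, device='meta')})
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated = balanced_split_pass(gm, 2)    # annotate split points for 2 pipeline stages
split_gm, _ = split_with_split_nodes_pass(annotated, merge_output=True)
print([name for name, _ in split_gm.named_children()])    # expect something like ['submod_0', 'submod_1']
```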
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 830e2bf2d..6a6c2379b 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
 import torch
 import torch.distributed.rpc as rpc
 from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
                                            split_batch, tensor_shape_list, type_detail)
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef
@@ -20,7 +20,7 @@ class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
     UPDATE = 2
-
+    INPUT = 3

 class UniqueKey:
     __slots__ = ('microbatch_id', 'phase')
@@ -128,6 +128,7 @@ class WorkerBase(ABC):
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
+        self.input_consumer_stage_ids: List[int] = None

         # module partitions
         self.partition_fn = partition_fn
@@ -135,6 +136,11 @@ class WorkerBase(ABC):
         self.criterion = criterion
         self.metric = metric

+        # middleware info
+        self._is_input = False
+        self._is_output = False
+        self._producer_consumer_initialized = False
+
         # context to maintain loop
         self._initialize_context_container()

@@ -164,6 +170,7 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
+        self.producer_consumer_init_lock = threading.Condition(threading.Lock())

     def _initialize_partition(self):
         partition_fn = self.partition_fn
@@ -182,7 +189,7 @@ class WorkerBase(ABC):
         # construction of partition is executed after the registion of pp_rank_to_worker_rref
         self._initialize_partition()

-    def get_output_by_key(self, key: UniqueKey) -> Any:
+    def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
@@ -191,8 +198,9 @@ class WorkerBase(ABC):
         if isinstance(output, Future):
             output = output.wait()

-        output_work_item.refcount += 1
+        # output_work_item.refcount += 1

+        # TODO(jiangziyue) redesign lifecycle management for DAG scheduler
         # all consumers have been satisfied, the work_item can be released
         with self.output_list_condition_lock:
             if output_work_item.refcount >= len(self.consumer_stage_ids):
@@ -215,8 +223,10 @@ class WorkerBase(ABC):
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
         return self.module_partition.state_dict()

-    def _make_args_kwargs(self, microbatch):
+    def _make_args_kwargs(self, microbatch, merge=False):
         if isinstance(microbatch, dict):
+            if merge:
+                return list(microbatch.values()), {}
             return [], microbatch
         elif isinstance(microbatch, torch.Tensor):
             return [microbatch], {}
@@ -228,24 +238,70 @@ class WorkerBase(ABC):
                     kwargs.update(arg)
                 else:
                     args.append(arg)
+            if merge:
+                arg_lst = args
+                for arg in kwargs.values():
+                    arg_lst.append(arg)
+                return arg_lst, {}
             return args, kwargs
         else:
             raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
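Reviewer note (not part of the patch): a standalone replica of the `merge=True` behaviour added to `_make_args_kwargs`, just to make the flattening rule explicit. The microbatch values are illustrative:

```python
import torch


def make_args_kwargs(microbatch, merge=False):
    # standalone replica of WorkerBase._make_args_kwargs, for illustration only
    if isinstance(microbatch, dict):
        if merge:
            return list(microbatch.values()), {}
        return [], microbatch
    elif isinstance(microbatch, torch.Tensor):
        return [microbatch], {}
    elif isinstance(microbatch, (tuple, list)):
        args, kwargs = [], {}
        for arg in microbatch:
            if isinstance(arg, dict):
                kwargs.update(arg)
            else:
                args.append(arg)
        if merge:
            return args + list(kwargs.values()), {}
        return args, kwargs
    raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")


x, y = torch.ones(4, 2), torch.zeros(4, 2)
print(make_args_kwargs({'x': x, 'y': y}))                 # ([], {'x': x, 'y': y})
print(make_args_kwargs({'x': x, 'y': y}, merge=True))     # ([x, y], {})
print(make_args_kwargs((x, {'y': y}), merge=True))        # ([x, y], {})
```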

     # just for first pp_rank
+    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
+    # TODO(jiangziyue) Define a Class for DAG.
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
         assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
+
+        if not self.use_middleware():
+            # make args and kwargs
+            args, kwargs = self._make_args_kwargs(microbatch)

-        # make args and kwargs
-        args, kwargs = self._make_args_kwargs(microbatch)
+            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
+                                 self.num_microbatches, forward_only)
+            with self.work_list_condition_lock:
+                self.work_list[key] = work_item
+                self.work_list_condition_lock.notify_all()
+        else:
+            # make args and kwargs
+            arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)

-        work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
-                             self.num_microbatches, forward_only)
-        with self.work_list_condition_lock:
-            self.work_list[key] = work_item
-            self.work_list_condition_lock.notify_all()
+            # the first stage assigns the correct inputs to the other stages
+            DAG = self.get_DAG()
+            DAG_node = DAG['input_partition']
+            self_input_offsets = []
+            recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
+            # notify rank which should receive extra input
+            offset = 0
+            for details in DAG_node.values():
+                for partition_name in details['output'].keys():
+                    recv_rank = self.partition_name_to_pp_rank(partition_name)
+                    if recv_rank == self.pp_rank:
+                        self_input_offsets.append(offset)
+                    elif recv_rank not in self.input_consumer_stage_ids:
+                        self.input_consumer_stage_ids.append(recv_rank)
+                offset += 1
+
+            # set input for self rank
+            self_arg_lst = []
+            for off in self_input_offsets:
+                self_arg_lst.append(arg_lst[off])
+
+            work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
+                                 self.num_microbatches, forward_only)
+            with self.work_list_condition_lock:
+                self.work_list[key] = work_item
+                self.work_list_condition_lock.notify_all()
+
+            # put input tensors which other nodes need into output_list
+            work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
+                                        self.num_microbatches, forward_only)
+
+            with self.output_list_condition_lock:
+                self.output_list[recv_input_key] = work_item_remote
+                self.output_list_condition_lock.notify_all()

     # just for last pp_rank
     def set_labels(self, microbatch_id: int, microlabels: Any):
@@ -268,33 +324,68 @@ class WorkerBase(ABC):
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()

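Reviewer note (not part of the patch): the diff never shows the schema of the `_DAG` attribute that `set_input` walks; judging from the lookups (`DAG['input_partition']`, `details['output']`, `partition_name_to_pp_rank`), it appears to look roughly like the dictionary below. Everything here, including the partition names and offsets, is an illustrative guess, not the actual structure produced by the split pass:

```python
# Hypothetical _DAG layout, inferred from how set_input and _get_real_args_kwargs index it.
example_DAG = {
    # one entry per model input; 'output' maps each consumer partition to the offsets it reads
    'input_partition': {
        'x': {'output': {'submod_0': [0]}},
    },
    # per-partition records: which producers feed which arg slots, and who consumes the outputs
    'submod_0': {
        'input': {'MODEL_INPUT': [0]},
        'output': {'submod_1': [0]},
        'output_len': 1,
    },
    'submod_1': {
        'input': {'submod_0': [0]},
        'output': {'MODEL_OUTPUT': [0]},
        'output_len': 1,
    },
}
```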
+    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
     def subscribe_producer(self, microbatch_id: int, forward_only: bool):
         """
         You should call this function asynchronously
         """
-        assert self.producer_stage_ids is not None
-        producer_num = len(self.producer_stage_ids)
-        assert producer_num > 0, "only stage that has producers can subscribe producers"
         stage_id = self.pp_rank
-        subscribe_forward_futures: List[Future] = [None] * producer_num
         output = self._get_future_by_device()
+        if not self.use_middleware():
+            producer_num = len(self.producer_stage_ids)
+            subscribe_forward_futures: List[Future] = [None] * producer_num
+            for i in range(producer_num):
+                producer_stage_id = self.producer_stage_ids[i]
+                producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
+                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
+        else:
+            with self.work_list_condition_lock:
+                key = UniqueKey(microbatch_id, Phase.FORWARD)
+                if key in self.work_list:
+                    return

-        for i in range(producer_num):
-            producer_stage_id = self.producer_stage_ids[i]
-            producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
-            producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-            subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
+            producer_stage_ids = []
+            with self.producer_consumer_init_lock:
+                self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
+                producer_stage_ids = self.producer_stage_ids
+            producer_num = len(producer_stage_ids)
+
+            # TODO(jiangziyue) get single value instead of the whole output
+            if self.need_model_input():
+                producer_num += 1    # extra one (the last one) for the input tensor
+            subscribe_forward_futures: List[Future] = [None] * producer_num
+
+            # TODO(jiangziyue) get single value instead of the whole output
+            if self.need_model_input():
+                producer_stage_id = 0
+                producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
+                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+
+                for i in range(0, producer_num - 1):
+                    producer_stage_id = producer_stage_ids[i]
+                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
+                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                    subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+
+            else:
+                for i in range(producer_num):
+                    producer_stage_id = producer_stage_ids[i]
+                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
+                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                    #producer_partition_name = self.pp_rank_to_partition_name(producer_stage_id)
+                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
-                                           microbatch_id, None, self.num_microbatches, forward_only)
+                                           microbatch_id, None, self.num_microbatches, forward_only)

         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
-            assert key not in self.work_list
-            self.work_list[key] = work_item_from_producer
-            self.work_list_condition_lock.notify_all()
+            if key not in self.work_list:
+                self.work_list[key] = work_item_from_producer
+                self.work_list_condition_lock.notify_all()
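Reviewer note (not part of the patch): under the middleware path, a stage that both has upstream producers and reads a raw model input reserves slot 0 of its futures list for the `Phase.INPUT` work item that stage 0 publishes in `set_input`, and shifts its regular producers to slots 1..N; `_get_real_args_kwargs` later relies on exactly this ordering. A tiny sketch of the slot layout (ranks are illustrative):

```python
# Illustrative only: how subscribe_producer orders its futures when need_model_input() is True.
producer_stage_ids = [1, 2]      # regular upstream stages for this worker
need_model_input = True          # this partition also consumes a raw model input

producer_num = len(producer_stage_ids) + (1 if need_model_input else 0)
slots = {}
if need_model_input:
    slots[0] = ('rank 0', 'Phase.INPUT')             # whole flattened model input from stage 0
    for i, rank in enumerate(producer_stage_ids):
        slots[i + 1] = (f'rank {rank}', 'Phase.FORWARD')
else:
    for i, rank in enumerate(producer_stage_ids):
        slots[i] = (f'rank {rank}', 'Phase.FORWARD')

print(producer_num, slots)
# 3 {0: ('rank 0', 'Phase.INPUT'), 1: ('rank 1', 'Phase.FORWARD'), 2: ('rank 2', 'Phase.FORWARD')}
```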

     def subscribe_consumer(self, microbatch_id: int):
         """
@@ -334,13 +425,132 @@ class WorkerBase(ABC):
         self.producer_stage_ids = []
         self.consumer_stage_ids = []

-        # Just for demo
-        prev_rank = rank - 1
-        next_rank = rank + 1
-        if prev_rank >= 0:
-            self.producer_stage_ids.append(prev_rank)
-        if next_rank <= self.actual_stage_num - 1:
-            self.consumer_stage_ids.append(next_rank)
+        if not self.use_middleware():
+            # Just for demo
+            prev_rank = rank - 1
+            next_rank = rank + 1
+            if prev_rank >= 0:
+                self.producer_stage_ids.append(prev_rank)
+            if next_rank <= self.actual_stage_num - 1:
+                self.consumer_stage_ids.append(next_rank)
+        else:
+            self.input_consumer_stage_ids = []
+            DAG = self.get_DAG()
+            DAG_node_name = self.pp_rank_to_partition_name(rank)
+            DAG_node = DAG[DAG_node_name]
+            for partition_name in DAG_node['input'].keys():
+                if partition_name == 'MODEL_INPUT':
+                    self._is_input = True
+                else:
+                    prev_rank = self.partition_name_to_pp_rank(partition_name)
+                    self.producer_stage_ids.append(prev_rank)
+
+            for partition_name in DAG_node['output'].keys():
+                if partition_name == 'MODEL_OUTPUT':
+                    self._is_output = True
+                else:
+                    next_rank = self.partition_name_to_pp_rank(partition_name)
+                    self.consumer_stage_ids.append(next_rank)
+
+        # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
+        with self.producer_consumer_init_lock:
+            self._producer_consumer_initialized = True
+            self.producer_consumer_init_lock.notify_all()
+
+    # TODO(jiangziyue) Define a Class for DAG.
+    def pp_rank_to_partition_name(self, pp_rank: int):
+        prefix = 'submod_'
+        partition_name = prefix + str(pp_rank)
+        return partition_name
+
+    # TODO(jiangziyue) Define a Class for DAG.
+    def partition_name_to_pp_rank(self, partition_name: str) -> int:
+        prefix = 'submod_'
+        pp_rank = int(partition_name.split(prefix)[-1])
+        return pp_rank
+
+    def get_DAG(self):
+        with self.partition_condition_lock:
+            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
+        if hasattr(self.module_partition, '_DAG'):
+            return self.module_partition._DAG
+        else:
+            return None
+
+    def use_middleware(self):
+        DAG = self.get_DAG()
+        return DAG is not None
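Reviewer note (not part of the patch): before reading `_get_real_args_kwargs` below, this standalone sketch shows the offset-based reassembly it performs on the middleware path, using the hypothetical DAG layout sketched earlier. Producers ship their whole output; the consumer picks the offsets recorded under the producer's `'output'` entry and drops them into the slots recorded under its own `'input'` entry:

```python
def reassemble(cur_name, DAG, outputs_by_producer):
    # illustrative replica of the offset bookkeeping in _get_real_args_kwargs
    cur_input = DAG[cur_name]['input']                               # producer name -> this stage's arg slots
    flatten_args = [None] * sum(len(offs) for offs in cur_input.values())
    for producer_name, input_offsets in cur_input.items():
        out_offsets = DAG[producer_name]['output'][cur_name]         # offsets taken from the producer's output
        picked = [outputs_by_producer[producer_name][off] for off in out_offsets]
        for i, dst in enumerate(input_offsets):
            flatten_args[dst] = picked[i]
    return flatten_args


toy_DAG = {
    'submod_0': {'output': {'submod_1': [0]}, 'output_len': 1},
    'submod_1': {'input': {'submod_0': [0]}},
}
print(reassemble('submod_1', toy_DAG, {'submod_0': ['activation_0']}))    # ['activation_0']
```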
+
+    # TODO(jiangziyue) get single value instead of the whole output
+    def _get_real_args_kwargs(self, args_or_kwargs):
+        if not self.use_middleware():
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    args_or_kwargs = flatten_args
+        else:
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    if self.is_first_stage():
+                        pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    # TODO get by offset
+                    else:
+                        DAG = self.get_DAG()
+                        producer_outputs = {}
+                        cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
+                        #cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
+                        for i, args_from_one_mod in enumerate(args_or_kwargs):
+                            producer_output_offsets = []
+                            if self.need_model_input():
+                                if i == 0:
+                                    producer_DAG_node = DAG['input_partition']
+                                    producer_partition_name = 'MODEL_INPUT'
+                                    offset = 0
+                                    for arg_info in producer_DAG_node.values():
+                                        if cur_DAG_node_name in arg_info['output']:
+                                            producer_output_offsets.append(offset)
+                                        offset += 1
+                                else:
+                                    producer_rank = self.producer_stage_ids[i - 1]
+                                    producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
+                                    producer_DAG_node = DAG[producer_partition_name]
+                                    producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
+
+                            else:
+                                producer_rank = self.producer_stage_ids[i]
+                                producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
+                                producer_DAG_node = DAG[producer_partition_name]
+                                producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
+
+                            if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
+                                producer_outputs[producer_partition_name] = [args_from_one_mod]
+                            else:
+                                producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
+
+                        cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
+
+                        def get_input_len(DAG_node_input):
+                            res = 0
+                            for offsets in DAG_node_input.values():
+                                res += len(offsets)
+                            return res
+
+                        input_len = get_input_len(cur_DAG_node_input)
+                        flatten_args = [None] * input_len
+                        for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
+                            for i, args_input_offset in enumerate(args_input_offsets):
+                                flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
+
+                    args_or_kwargs = flatten_args
+        return args_or_kwargs

     @abstractmethod
     def _get_work_item_key(self) -> UniqueKey:
@@ -353,6 +563,9 @@ class WorkerBase(ABC):

     def is_last_stage(self):
         return self.pp_rank == self.actual_stage_num - 1
+
+    def need_model_input(self):
+        return not self.is_first_stage() and self._is_input

     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
@@ -390,11 +603,11 @@ class WorkerBase(ABC):

         # parse and integrate args and kwargs
         if is_first_stage:
-            args = get_real_args_kwargs(args)
-            kwargs = get_real_args_kwargs(kwargs)
+            args = self._get_real_args_kwargs(args)
+            kwargs = self._get_real_args_kwargs(kwargs)
             args_kwargs = (args, kwargs)
         else:
-            args_kwargs = get_real_args_kwargs(args)
+            args_kwargs = self._get_real_args_kwargs(args)

         args, kwargs = data_process_func(args_kwargs)

@@ -486,7 +699,7 @@ class WorkerBase(ABC):

         # overlap recompute and future.wait
         if not is_last_stage:
-            grad_tensors = get_real_args_kwargs(args)
+            grad_tensors = self._get_real_args_kwargs(args)
         else:
             grad_tensors = None

@@ -569,7 +782,10 @@ class WorkerBase(ABC):
         self._reset_context()

     def initialize_optimizer(self, optimizer_class: type, **kwargs):
-        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
+        # After tracer fixed, remove this part.
+        if len(list(self.module_partition.parameters())) > 0:
+            self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
         self.step_lock = threading.Lock()
         self.step_lock.acquire()

@@ -577,8 +793,11 @@ class WorkerBase(ABC):
         self.step_lock.acquire()

     def step(self):
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
+        # After tracer fixed, remove this part.
+        if len(list(self.module_partition.parameters())) > 0:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
         self.step_lock.release()

diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
index fe0333bde..f1a4116be 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -20,6 +20,18 @@ def color_debug(text, prefix=' ', color='blue'):
     color = color.upper()
     print(getattr(Back, color), prefix, Style.RESET_ALL, text)

+class MLP(nn.Module):
+    def __init__(self, dim: int, layers: int):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+        for _ in range(layers):
+            self.layers.append(nn.Linear(dim, dim, bias=False))
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x

 class RpcTestModel(nn.Module):

diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py
new file mode 100644
index 000000000..ea9a3c16e
--- /dev/null
+++ b/tests/test_pipeline/test_middleware_1f1b.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+from colossalai.fx import ColoTracer
+from rpc_test_utils import rpc_run, parse_args, MLP
+from functools import partial
+
+# global variables for model creation
+batch_size = 16
+dim = 10
+
+def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+    model.eval()
+    tracer = ColoTracer()
+    meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+    annotated_model = balanced_split_pass(gm, stage_num)
+    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    return list(split_model.children())[pp_rank]
+
+def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    model = MLP(dim, stage_num * 3)
+    partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
+    return partition
+
+def run_master(args):
+    torch.manual_seed(100)
+
+    epoch = args.epoch
+    device = args.device
+    stage_num = args.world_size
+    chunk = args.chunk
+    num_microbatches = args.num_microbatches
+    use_checkpoint = args.use_checkpoint
+
+    input_sample = torch.randn((batch_size, dim), device=device)
+
+    def data_gen():
+        x = torch.zeros((batch_size, dim))
+        kwargs = dict(x=x)
+        return kwargs
+
+    data_kwargs = data_gen()
+    engine = OneFOneBPipelineEngine(partition_fn=partial(partition, data_kwargs),
+                                    stage_num=stage_num,
+                                    num_microbatches=num_microbatches,
+                                    device=device,
+                                    chunk=chunk,
+                                    checkpoint=use_checkpoint)
+
+    for _ in range(epoch):
+        logits = engine.forward_backward({'x': input_sample}, forward_only=True)
+
+if __name__ == "__main__":
+    args = parse_args()
+    rpc_run(args, run_master)
\ No newline at end of file