def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """Mark ``pp_size`` pipeline stages holding roughly the same number of nodes.

    After every ``len(gm.graph.nodes) // pp_size`` nodes a ``call_function``
    node targeting ``pipe_split`` is inserted into ``gm.graph``, until
    ``pp_size - 1`` such markers have been placed.

    Args:
        gm: graph module to annotate. Its graph is modified in place.
        pp_size: number of stages wanted. Any value ``<= 1`` (including ``0``
            and negatives) inserts no marker.

    Returns:
        The same ``gm`` object, recompiled.
    """
    mod_graph = gm.graph
    if pp_size > 1:
        # Iterate over a snapshot: ``graph.nodes`` is a linked list that is
        # walked live, so a marker inserted after the current node would
        # otherwise be visited next and counted like an ordinary node (and,
        # for tiny averages, immediately trigger another marker).
        nodes = list(mod_graph.nodes)
        avg_num_node = len(nodes) // pp_size
        remaining_splits = pp_size - 1
        accumulate_num_node = 0
        for node in nodes:
            # Stop once every marker is placed; never place one after the
            # graph's output node.
            if remaining_splits <= 0 or node.op == 'output':
                break
            accumulate_num_node += 1
            if accumulate_num_node < avg_num_node:
                continue
            accumulate_num_node = 0
            remaining_splits -= 1
            # When the output node follows directly, put the marker in front
            # of the current node so that something other than the output
            # remains after the last marker.
            if node.next.op == 'output':
                insert_point = mod_graph.inserting_before(node)
            else:
                insert_point = mod_graph.inserting_after(node)
            with insert_point:
                mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index ae1cbb0c4..ace834294 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -16,6 +16,7 @@ from colossalai.pipeline.middleware import Partition, PartitionInputVal, Partiti from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.rpc.utils import ( get_batch_lengths, + pyobj_map, pytree_filter, pytree_map, split_batch, @@ -199,38 +200,30 @@ class WorkerBase(ABC): with self.output_list_condition_lock: self.output_list_condition_lock.wait_for(lambda: key in self.output_list) output_work_item = self.output_list[key] - self.output_list.pop(key) + output = output_work_item.output + if not ref_use and output_work_item.phase != Phase.INPUT: + self.output_list.pop(key) - if not ref_use: + if not ref_use and output_work_item.phase != Phase.INPUT: output_work_item.refcount += 1 - refcount = output_work_item.refcount - output = output_work_item.output - - if output_work_item.phase == Phase.FORWARD: + refcount = output_work_item.refcount # lifecycle management for DAG scheduler - lifecycle = len(self.get_consumer_stage_ids()) - if self.is_model_output(): # an extra reference for scheduler collecting results - lifecycle += 1 + if output_work_item.phase == Phase.FORWARD: + lifecycle = len(self.get_consumer_stage_ids()) + if self.is_model_output(): # an extra reference for scheduler collecting results + lifecycle += 1 + elif output_work_item.phase == Phase.BACKWARD: + lifecycle = len(self.get_producer_stage_ids()) + if self._is_last_step(output_work_item): # an extra reference for ensure_backward + lifecycle += 1 + else: + lifecycle = 0 + refcount = 0 + with self.output_list_condition_lock: - # all consumers have been satisfied, the work_item can be released - # or put it into work list again. 
if refcount < lifecycle: self.output_list[key] = output_work_item self.output_list_condition_lock.notify_all() - elif output_work_item.phase == Phase.BACKWARD: - lifecycle = len(self.get_producer_stage_ids()) - if self._is_last_step(output_work_item): - lifecycle += 1 # an extra reference for scheduler collecting results - with self.output_list_condition_lock: - # all producers have been satisfied, the work_item can be released - # or put it into work list again. - if refcount < lifecycle: - self.output_list[key] = output_work_item - self.output_list_condition_lock.notify_all() - else: - with self.output_list_condition_lock: - self.output_list[key] = output_work_item - self.output_list_condition_lock.notify_all() if isinstance(output, Future): output = output.wait() @@ -689,10 +682,12 @@ class WorkerBase(ABC): else: args_kwargs = self._get_real_args_kwargs_fwd(args) - if not forward_only: - pytree_map(args_kwargs, - lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False), - process_types=torch.Tensor) + # if not forward_only: + # pytree_map(args_kwargs, + # lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False), + # process_types=torch.Tensor) + args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU args, kwargs = data_process_func(args_kwargs) @@ -762,6 +757,9 @@ class WorkerBase(ABC): if is_last_stage: # if it is the last stage, trigger backward automatic self._begin_backward(microbatch_id) + consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + elif phase == Phase.BACKWARD: # remind its producer to get data before backward if not is_first_stage: @@ -807,6 +805,8 @@ class WorkerBase(ABC): stage_outputs = filtered_outputs grad_tensors = filtered_grads + grad_tensors = pyobj_map(grad_tensors, 
fn=lambda x: x.to(self.device), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU autograd.backward(stage_outputs, grad_tensors=grad_tensors) # collect grad of input tensor @@ -818,6 +818,9 @@ class WorkerBase(ABC): consume_result.append(arg.grad) else: consume_result.append(None) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to('cpu'), + process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -882,9 +885,6 @@ class WorkerBase(ABC): # if is last step in one batch reset context and do step if self._is_last_step(work_item): - self._hook_before_step() - if hasattr(self, 'optimizer') and not work_item.forward_only: - self.step() self._wait_for_reset() # reset context and resume loop @@ -904,23 +904,12 @@ class WorkerBase(ABC): self.reset_condition.notify_all() def initialize_optimizer(self, optimizer_class: type, **kwargs): - # TODO(jiangziyue) it's temporary code to deal with empty module partition. - # After tracer fixed, remove this part. - if len(list(self.module_partition.parameters())) > 0: - self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) - self.step_lock = threading.Lock() - self.step_lock.acquire() - - def wait_for_step(self): - self.step_lock.acquire() + self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs) def step(self): - # TODO(jiangziyue) it's temporary code to deal with empty module partition. - # After tracer fixed, remove this part. 
- if len(list(self.module_partition.parameters())) > 0: - self.optimizer.step() - self.optimizer.zero_grad() - self.step_lock.release() + self._hook_before_step() + self.optimizer.step() + self.optimizer.zero_grad() class PipelineEngineBase(ABC, nn.Module): @@ -1176,10 +1165,7 @@ class PipelineEngineBase(ABC, nn.Module): forward_result = self._collect_forward_result(output_pp_ranks, ret_future) if not forward_only and hasattr(self, 'optimizer_class'): - # wait for all step - for pp_rank in self.pp_rank_to_worker_rref: - worker_rref = self.pp_rank_to_worker_rref[pp_rank] - worker_rref.rpc_sync().wait_for_step() + self.step() self._reset_worker() # reset worker attributes for next batch return forward_result diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index 555955583..e6aa961f1 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -3,11 +3,12 @@ from typing import Callable, Dict, List import torch import torch.distributed as dist -from colossalai.pipeline.pipeline_process_group import ppg -from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem) from torch._C._distributed_rpc import PyRRef from torch.futures import Future +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem + # Implementation of different Pipeline schedule # Worker defines the worker for each stage # PipelineEngine is the class for use @@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase): outstanding_min = actual_stage_num - pp_rank - 1 outstanding_max = actual_stage_num - pp_rank self.outstanding_range = (outstanding_min, outstanding_max) - elif target_key.microbatch_id == num_microbatches - 1: + if target_key.microbatch_id == num_microbatches - 1: self.outstanding_range = (0, 0) return target_key diff --git 
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
    """Return a copy of ``obj`` with ``fn`` applied to nested ``process_types`` values.

    The object is walked recursively. Any value that is an instance of
    ``process_types`` is replaced by ``fn(value)``; plain ``dict``, ``tuple``
    and ``list`` containers are rebuilt with their items mapped the same way;
    everything else is returned unchanged.

    Only the exact built-in container types are traversed. Subclasses of
    ``dict``/``tuple``/``list`` are returned as they are, unless they match
    ``process_types`` themselves.

    Args:
        obj: arbitrarily nested object to process.
        fn: callable applied to each matching value.
        process_types: type, or tuple of types, selecting the values handed
            to ``fn``. With the default empty tuple nothing matches.

    Returns:
        A structure shaped like ``obj`` with the matching values replaced.
    """
    # The match test comes first so that a container type listed in
    # ``process_types`` is handed to ``fn`` whole instead of being traversed.
    if isinstance(obj, process_types):
        return fn(obj)
    kind = type(obj)
    if kind is dict:
        return {k: pyobj_map(v, fn, process_types) for k, v in obj.items()}
    if kind is tuple:
        return tuple(pyobj_map(o, fn, process_types) for o in obj)
    if kind is list:
        return [pyobj_map(o, fn, process_types) for o in obj]
    return obj
map_all) for k in obj} @@ -57,6 +71,7 @@ def split_batch(batch: Any, start, stop, device: str): def type_detail(obj): return pytree_map(obj, lambda x: type(x), map_all=True) + def pytree_filter(fn, obj, process_types): if obj is None: return None