diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index ace834294..cbbd317e4 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -185,18 +185,7 @@ class WorkerBase(ABC):
             self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
             self.partition_condition_lock.notify_all()
 
-    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
-        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
-        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
-        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
-
-        # for some schedule need the other worker's info to initialise partition (like Chimera)
-        # construction of partition is executed after the registion of pp_rank_to_worker_rref
-        self._initialize_partition()
-
-    # res_use works for lifecycle counter,
-    # if ref_use is True, lifecycle won't add.
-    def get_output_by_key(self, key: UniqueKey, ref_use=False) -> Any:
+    def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
@@ -214,7 +203,8 @@ class WorkerBase(ABC):
                 lifecycle += 1
             elif output_work_item.phase == Phase.BACKWARD:
                 lifecycle = len(self.get_producer_stage_ids())
-                if self._is_last_step(output_work_item):    # an extra reference for ensure_backward
+                if self.is_model_input() and self._is_last_step(
+                        output_work_item):    # an extra reference for ensure_backward
                     lifecycle += 1
             else:
                 lifecycle = 0
@@ -230,6 +220,26 @@ class WorkerBase(ABC):
 
         return output
 
+    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
+        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has already synced global worker rrefs"
+        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
+        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
+
+        # some schedules (like Chimera) need the other workers' info to initialise the partition,
+        # so construction of the partition is executed after the registration of pp_rank_to_worker_rref
+        self._initialize_partition()
+
+    # ref_use works for the lifecycle counter:
+    # if ref_use is True, the lifecycle count won't increase.
+    # offsets supports getting partial output to reduce comm costs.
+    def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
+        output = self._get_output_all(key, ref_use, rank)
+        if offsets is None:    # get all for a non-iterable output
+            return output
+        else:    # get part of an iterable output
+            output = [output[i] for i in offsets]
+        return output
+
     def get_parameters(self) -> List[torch.Tensor]:
         return [p for p in self.module_partition.parameters()]
 
@@ -361,22 +371,35 @@ class WorkerBase(ABC):
                 producer_stage_id = 0
                 producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
+                offsets = self._get_input_offsets_by_index(target_index=0)
+                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
+                                                                                                  rank=self.pp_rank,
+                                                                                                  offsets=offsets)
 
                 for i in range(0, producer_num - 1):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i + 1] = producer_worker_rref.rpc_async().get_output_by_key(
-                        producer_output_key)
+                    target_index = i + 1
+                    offsets = self._get_input_offsets_by_index(target_index=target_index)
+                    if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                        subscribe_forward_futures[target_index] = []
+                    else:
+                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
+                            producer_output_key, rank=self.pp_rank)
 
             else:
                 for i in range(producer_num):
                     producer_stage_id = producer_stage_ids[i]
                     producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                     producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
-                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(
-                        producer_output_key)
+                    target_index = i
+                    offsets = self._get_input_offsets_by_index(target_index=target_index)
+                    if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                        subscribe_forward_futures[target_index] = []
+                    else:
+                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
+                            producer_output_key, rank=self.pp_rank, offsets=offsets)
 
         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)
@@ -412,7 +435,13 @@ class WorkerBase(ABC):
             consumer_stage_id = consumer_stage_ids[i]
             consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
             consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
-            subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
+            target_index = i
+            offsets = self._get_output_offsets_by_index(target_index=target_index)
+            if offsets is not None and len(offsets) == 0:    # no need to do rpc
+                subscribe_backward_futures[target_index] = []
+            else:
+                subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
+                    consumer_output_key, rank=self.pp_rank, offsets=offsets)
 
         # flatten args
         work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
@@ -501,6 +530,75 @@ class WorkerBase(ABC):
         topo = self.get_topo()
         return topo is not None
 
+    def _get_input_offsets_by_index(self, target_index):
+        res = []
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+        model_input_partition_id = topo.get_input_partition_id()
+        input_vals = self_partition.get_input_vals()
+        producer_stage_ids = self.get_producer_stage_ids()
+        if self.need_model_input():
+            # 0 for data from the input batch
+            # >= 1 for data from prev stages
+            base = 1
+        else:
+            # data from prev stages
+            base = 0
+        for val in input_vals:
+            val_pos = val.get()
+            src_partition_id = val_pos.partition_id
+            src_offset = val_pos.offset
+            src_index = base
+            src_partition = topo.get_partition_by_id(src_partition_id)
+            output_len = len(src_partition.get_output_vals())
+            # data from a non-input partition
+            if src_partition_id != model_input_partition_id:
+                src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
+                src_index = base
+                for i, stage_id in enumerate(producer_stage_ids):
+                    if stage_id == src_stage_id:
+                        src_index += i
+                        break
+            else:    # data from the input partition
+                src_index = 0
+            # when output_len == 1, the output is not iterable
+            if target_index == src_index:
+                if output_len == 1:
+                    res = None    # offsets = None to get all outputs
+                    return res
+                else:
+                    res.append(src_offset)
+        return res
+
+    def _get_output_offsets_by_index(self, target_index):
+        res = []
+        topo: Topo = self.get_topo()
+        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
+        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
+        output_vals = self_partition.get_output_vals()
+        consumer_stage_ids = self.get_consumer_stage_ids()
+        for val_list in output_vals:
+            # An output may be passed to many downstream stages.
+            target = None
+            for val_pos in val_list.get():
+                dst_partition_id = val_pos.partition_id
+                dst_offset = val_pos.offset
+                dst_partition = topo.get_partition_by_id(dst_partition_id)
+                input_len = len(dst_partition.get_input_vals())
+                dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
+                for i, stage_id in enumerate(consumer_stage_ids):
+                    if stage_id == dst_stage_id:
+                        dst_index = i
+                        break
+                if target_index == dst_index:
+                    if input_len == 1:
+                        res = None    # offsets = None to get all outputs
+                        return res
+                    else:
+                        res.append(dst_offset)
+        return res
+
     # TODO(jiangziyue) get single value instead of the whole output
     def _get_real_args_kwargs_fwd(self, args_or_kwargs):
         if not self.use_middleware():
@@ -521,8 +619,7 @@ class WorkerBase(ABC):
             flatten_args = []
             if self.is_first_stage():
                 pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
-            # TODO get by offset
-            else:
+            else:    # get by offset
                 topo: Topo = self.get_topo()
                 self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
                 self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -557,7 +654,9 @@ class WorkerBase(ABC):
                    if output_len == 1:
                        target = args_or_kwargs[src_index]
                    else:
-                        target = args_or_kwargs[src_index][src_offset]
+                        offsets = self._get_input_offsets_by_index(src_index)
+                        real_offset = offsets.index(src_offset)
+                        target = args_or_kwargs[src_index][real_offset]
                    flatten_args.append(target)
             args_or_kwargs = flatten_args
         return args_or_kwargs
@@ -574,10 +673,10 @@ class WorkerBase(ABC):
                 pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
                 args_or_kwargs = flatten_args
             else:
-                args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
-                if args_or_kwargs is not None:
+                for i, arg in enumerate(args_or_kwargs):
+                    args_or_kwargs[i] = arg.wait()
+                if args_or_kwargs is not None:    # get by offset
                     flatten_args = []
-                    # TODO get by offset
                     topo: Topo = self.get_topo()
                     self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
                     self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -599,7 +698,9 @@ class WorkerBase(ABC):
                            if input_len == 1:
                                part_grad = args_or_kwargs[dst_index]
                            else:
-                                offsets = args_or_kwargs[dst_index][dst_offset]
+                                offsets = self._get_output_offsets_by_index(dst_index)
+                                real_offset = offsets.index(dst_offset)
+                                part_grad = args_or_kwargs[dst_index][real_offset]
 
                            if target is None:
                                target = part_grad
@@ -682,10 +783,6 @@ class WorkerBase(ABC):
             else:
                 args_kwargs = self._get_real_args_kwargs_fwd(args)
 
-        # if not forward_only:
-        #     pytree_map(args_kwargs,
-        #                lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
-        #                process_types=torch.Tensor)
         args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
                                 process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
 
@@ -752,14 +849,14 @@ class WorkerBase(ABC):
                                                              stage_input_kwargs,
                                                              stage_outputs,
                                                              checkpoint=use_checkpoint)
+            consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
+                                       process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+
             # if not forward_only, do the backward
             if not forward_only:
                 if is_last_stage:    # if it is the last stage, trigger backward automatic
                     self._begin_backward(microbatch_id)
 
-                consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
-                                           process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
-
         elif phase == Phase.BACKWARD:
             # remind its producer to get data before backward
             if not is_first_stage:
@@ -803,10 +900,8 @@ class WorkerBase(ABC):
                    filtered_grads.append(grad)
 
             stage_outputs = filtered_outputs
-            grad_tensors = filtered_grads
-
-            grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
-                                     process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+            grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
+                                     process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
         # collect grad of input tensor
@@ -941,8 +1036,6 @@ class PipelineEngineBase(ABC, nn.Module):
 
         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
 
-        self.step_futs: List[Future] = []
-
         self._check_argument()
         self._create_pp_rank_to_rpc_worker_id()
         self._create_pp_rank_to_module_partition_id()
@@ -1058,9 +1151,14 @@ class PipelineEngineBase(ABC, nn.Module):
                    ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
            else:
                key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
+                futs = []
                for pp_rank in input_pp_ranks:
                    worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                    worker_rref.rpc_sync().get_output_by_key(key, ref_use=True)
+                    fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
+                    futs.append(fut)
+
+                for fut in futs:
+                    fut.wait()
 
     def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
         num_microbatches = self.num_microbatches
@@ -1087,10 +1185,16 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
         if not forward_only:
+            backward_result = []
            for pp_rank in input_pp_ranks:
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-                worker_rref.rpc_sync().get_output_by_key(key)
+                fut = worker_rref.rpc_async().get_output_by_key(
+                    key, offsets=[])    # only ensure the result exists, no need for real data
+                backward_result.append(fut)
+
+            for fut in backward_result:
+                fut.wait()
 
     def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
         forward_result = []
@@ -1109,12 +1213,13 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def _reset_worker(self):
         actual_stage_num = self._get_actual_stage_num()
+        reset_futs: List[Future] = []
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().reset_context()
-            self.step_futs.append(fut)
+            reset_futs.append(fut)
 
-        for fut in self.step_futs:
+        for fut in reset_futs:
             fut.wait()
 
     def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
@@ -1141,7 +1246,7 @@ class PipelineEngineBase(ABC, nn.Module):
         for microbatch_id in range(num_microbatches):
             # control data input speed
             # to prevent exceed of wait limitations
-            self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
+            # self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
             batch_start = microbatch_size * microbatch_id
             batch_end = min(batch_start + microbatch_size, batch_length)
 
@@ -1178,10 +1283,11 @@ class PipelineEngineBase(ABC, nn.Module):
 
     def step(self):
         actual_stage_num = self._get_actual_stage_num()
+        step_futs: List[Future] = []
         for pp_rank in range(actual_stage_num):
             worker_rref = self.pp_rank_to_worker_rref[pp_rank]
             fut = worker_rref.rpc_async().step()
-            self.step_futs.append(fut)
+            step_futs.append(fut)
 
-        for fut in self.step_futs:
+        for fut in step_futs:
             fut.wait()
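
Reviewer note: below is a minimal standalone sketch of the offsets contract this patch threads through get_output_by_key (the helper name select_by_offsets and the toy data are hypothetical, not part of the patch). offsets=None fetches the whole output (the non-iterable case), a non-empty list fetches only the selected slots, and an empty list means the consumer needs nothing from that producer, which is why the subscribe paths above skip the RPC entirely and why _consume_constraint/_ensure_backward pass offsets=[] purely as an existence check.

    from typing import Any, List, Optional

    def select_by_offsets(output: Any, offsets: Optional[List[int]] = None) -> Any:
        # Mirrors the slicing step in WorkerBase.get_output_by_key.
        if offsets is None:    # non-iterable output: return everything
            return output
        return [output[i] for i in offsets]    # iterable output: only the requested slots

    producer_output = ["t0", "t1", "t2"]    # toy stand-in for a stage's output tuple
    assert select_by_offsets(producer_output) == ["t0", "t1", "t2"]
    assert select_by_offsets(producer_output, [2]) == ["t2"]    # partial fetch -> less comm
    assert select_by_offsets(producer_output, []) == []    # nothing needed; callers skip the rpc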