[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)

* add splitter

* polish code

* remove comment

* fix async NaN by moving tensors to CPU first (see the sketch below)
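The NaN fix in the last bullet follows a common PyTorch RPC pattern: a CUDA tensor handed to an in-flight async send can be serialized before the kernel producing it has finished. A blocking copy to CPU synchronizes with the producing stream first, so the serialized bytes are final. A minimal hedged sketch (the function name and wiring are illustrative, not the patch's actual code):

```python
import torch

def stage_output_for_rpc(activation: torch.Tensor) -> torch.Tensor:
    """Illustrative: prepare a CUDA activation for an async RPC send."""
    # .cpu() is a blocking device-to-host copy: it waits for the CUDA
    # kernels producing the tensor, so its values are complete before
    # RPC serialization reads them (avoiding garbage/NaN bytes).
    return activation.detach().cpu()
```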

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Author: Ziyue Jiang
Date: 2022-12-23 11:38:43 +08:00
Commit: 59e343328d (parent: 937f404253)
4 changed files with 84 additions and 58 deletions


@@ -3,11 +3,12 @@ from typing import Callable, Dict, List
 import torch
 import torch.distributed as dist
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
 # <strategy>PipelineEngine is the class for use
@@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
                 outstanding_min = actual_stage_num - pp_rank - 1
                 outstanding_max = actual_stage_num - pp_rank
                 self.outstanding_range = (outstanding_min, outstanding_max)
-            elif target_key.microbatch_id == num_microbatches - 1:
+            if target_key.microbatch_id == num_microbatches - 1:
                 self.outstanding_range = (0, 0)
         return target_key
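For context on why swapping `elif` for `if` removes the deadlock: when num_microbatches equals actual_stage_num, the microbatch that widens `outstanding_range` on rank 0 is also the last microbatch, so with `elif` the reset to `(0, 0)` never runs and the worker keeps waiting for forwards that never arrive. A minimal standalone sketch of that control flow (the helper function and print harness are illustrative, not part of the patch):

```python
from typing import Optional, Tuple

def next_outstanding_range(microbatch_id: int,
                           pp_rank: int,
                           actual_stage_num: int,
                           num_microbatches: int,
                           use_elif: bool) -> Optional[Tuple[int, int]]:
    """Illustrative stand-in for OneFOneBWorker's range update."""
    outstanding_range = None
    if microbatch_id == actual_stage_num - pp_rank - 1:
        # Entering the steady 1F1B phase: allow extra outstanding forwards.
        outstanding_range = (actual_stage_num - pp_rank - 1,
                             actual_stage_num - pp_rank)
        if use_elif:
            # Mimics the old `elif`: the reset branch below is skipped
            # whenever this branch fires.
            return outstanding_range
    if microbatch_id == num_microbatches - 1:
        # Last microbatch: shrink the range so the worker drains and exits.
        outstanding_range = (0, 0)
    return outstanding_range

# With num_microbatches == actual_stage_num (here both 4), rank 0 hits both
# conditions on the same microbatch. The old `elif` leaves the range open:
print(next_outstanding_range(3, 0, 4, 4, use_elif=True))   # (3, 4) -> hangs
print(next_outstanding_range(3, 0, 4, 4, use_elif=False))  # (0, 0) -> drains
```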