Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-07 18:15:56 +00:00
[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)
* add splitter
* polish code
* remove comment
* fix async NaN by moving to CPU first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
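The "fix async NaN by moving to CPU first" bullet refers to serializing tensors for an asynchronous RPC send. A minimal sketch of the idea, with a hypothetical helper name (the actual call site is not part of the hunks shown below):

```python
import torch

def to_rpc_payload(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: detach and copy to CPU *before* handing the
    # tensor to an async RPC send. tensor.cpu() is a blocking copy, so it
    # waits for any in-flight CUDA kernels that are still writing to the
    # tensor; serializing the live CUDA tensor instead can ship
    # half-written values that surface downstream as NaNs.
    return tensor.detach().cpu()
```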
```diff
@@ -3,11 +3,12 @@ from typing import Callable, Dict, List

 import torch
 import torch.distributed as dist
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future

+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
 # <strategy>PipelineEngine is the class for use
```
||||
```diff
@@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
            outstanding_min = actual_stage_num - pp_rank - 1
            outstanding_max = actual_stage_num - pp_rank
            self.outstanding_range = (outstanding_min, outstanding_max)
-        elif target_key.microbatch_id == num_microbatches - 1:
+        if target_key.microbatch_id == num_microbatches - 1:
            self.outstanding_range = (0, 0)

        return target_key
```
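Why swapping `elif` for `if` removes the deadlock: when `num_microbatches == num_stages`, the microbatch that opens the steady-state 1F1B outstanding window can also be the last microbatch, so both branches apply to the same key. With `elif`, the second branch is then skipped, `self.outstanding_range` never collapses to `(0, 0)`, and the worker waits forever for forward work that never arrives. A standalone sketch of the patched control flow (the condition on the first branch is a hypothetical stand-in, since the real one sits above the lines shown in the hunk):

```python
def update_outstanding_range(microbatch_id: int, num_microbatches: int,
                             actual_stage_num: int, pp_rank: int):
    """Sketch of the patched branch logic; names mirror the hunk above."""
    outstanding_range = None
    # Hypothetical trigger for entering steady-state 1F1B.
    if microbatch_id == actual_stage_num - 1:
        outstanding_min = actual_stage_num - pp_rank - 1
        outstanding_max = actual_stage_num - pp_rank
        outstanding_range = (outstanding_min, outstanding_max)
    # Patched: `if`, not `elif`. When num_microbatches == actual_stage_num,
    # microbatch actual_stage_num - 1 is *also* the last microbatch, and the
    # range must still be closed to (0, 0) or the worker deadlocks.
    if microbatch_id == num_microbatches - 1:
        outstanding_range = (0, 0)
    return outstanding_range
```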