Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-12 21:25:53 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones_like, don't gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restructure dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoprallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
Support the full evoformer tracer, which is a main module of AlphaFold; previously only a simplified version was supported.
1. support some of evoformer's ops in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen is set automatically in __init__, which prevents any other codegen from being set.
This adds an argument to control whether ActivationCheckpointCodeGen is set in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
import threading
from typing import Callable, Dict, List

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem

# Implementations of different pipeline schedules.
# <strategy>Worker defines the scheduling behaviour of each stage.
# <strategy>PipelineEngine is the user-facing engine class.
class FillDrainWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # execute backward first (if backward phase in work_list)
        num_microbatches = self.num_microbatches

        if self.forward_times < num_microbatches:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        else:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times

        target_key = UniqueKey(target_microbatch_id, target_phase)

        return target_key
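The fill-drain rule above is easiest to see in isolation. The following standalone sketch is illustrative only and not part of this file: it replays the branch in _get_work_item_key under the assumption that each issued key simply advances the matching counter (the real counters are maintained by WorkerBase). With four microbatches, every forward is scheduled before any backward.

# Minimal fill-drain sketch (assumption: each issued key bumps its counter).
def fill_drain_order(num_microbatches: int):
    forward_times, backward_times = 0, 0
    order = []
    for _ in range(2 * num_microbatches):
        if forward_times < num_microbatches:    # same branch as _get_work_item_key
            order.append(f"F{forward_times}")
            forward_times += 1
        else:
            order.append(f"B{backward_times}")
            backward_times += 1
    return order

print(fill_drain_order(4))    # ['F0', 'F1', 'F2', 'F3', 'B0', 'B1', 'B2', 'B3']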
class FillDrainPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        use_1F1B = False

        super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)
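Both the fill-drain and 1F1B engines enforce the same interleaving constraint in their constructors: when chunk > 1, num_microbatches must divide evenly by stage_num. A minimal, self-contained check of that constraint is sketched below; the helper name is made up for this illustration and is not part of the library.

def interleaving_args_ok(stage_num: int, num_microbatches: int, chunk: int) -> bool:
    # mirrors the assert above: interleaving (chunk > 1) requires
    # num_microbatches to be a multiple of stage_num
    return chunk == 1 or num_microbatches % stage_num == 0

print(interleaving_args_ok(stage_num=4, num_microbatches=8, chunk=2))    # True
print(interleaving_args_ok(stage_num=4, num_microbatches=6, chunk=2))    # False -> the assert would fire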
class OneFOneBWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # execute backward first (if backward phase in work_list)
        pp_rank = self.pp_rank
        actual_stage_num = self.actual_stage_num
        num_microbatches = self.num_microbatches
        is_last_stage = pp_rank == actual_stage_num - 1

        if self.outstanding <= self.outstanding_range[0]:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        elif self.outstanding >= self.outstanding_range[1]:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:
            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

        target_key = UniqueKey(target_microbatch_id, target_phase)

        # change outstanding_range at:
        # 1. forward times reach actual_stage_num, this is the end of continuous forward
        # 2. forward times reach num_microbatches, this is the end of 1F1B mode
        if not is_last_stage and target_key.phase == Phase.FORWARD:
            if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
                # why require num_microbatches > 2? because there is no steady stage when num_microbatches <= 2
                outstanding_min = actual_stage_num - pp_rank - 1
                outstanding_max = actual_stage_num - pp_rank
                self.outstanding_range = (outstanding_min, outstanding_max)
            if target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

        return target_key
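The steady-state window set above determines how many forward microbatches each non-last stage keeps in flight during 1F1B. The standalone sketch below (illustrative only) evaluates the same two lines for a 4-stage pipeline; the last stage is omitted because the is_last_stage check above excludes it from the update.

def steady_outstanding_range(pp_rank: int, stage_num: int):
    # same arithmetic as OneFOneBWorker._get_work_item_key once forward
    # microbatch `stage_num - 1` has been scheduled
    return (stage_num - pp_rank - 1, stage_num - pp_rank)

for pp_rank in range(3):    # non-last stages of a 4-stage pipeline
    print(pp_rank, steady_outstanding_range(pp_rank, stage_num=4))
# 0 (3, 4)
# 1 (2, 3)
# 2 (1, 2)

Earlier stages hold more outstanding microbatches, which is what keeps the pipeline full during the steady phase.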
class OneFOneBPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        use_1F1B = True

        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)
class ChimeraWorker(WorkerBase):

    def _get_producer_consumer(self) -> None:
        rank = self.pp_rank
        min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
        max_pp_rank = min_pp_rank + self.actual_stage_num - 1

        assert self.producer_stage_ids is None, f"all the producers of rank {rank} have been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} have been subscribed"

        # should be arranged in order, the order of the input of the current forward
        self.producer_stage_ids = []
        self.consumer_stage_ids = []

        # Just for demo
        prev_rank = rank - 1
        next_rank = rank + 1
        if prev_rank >= min_pp_rank:
            self.producer_stage_ids.append(prev_rank)
        if next_rank <= max_pp_rank:
            self.consumer_stage_ids.append(next_rank)

    def _get_work_item_key(self) -> UniqueKey:
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        real_microbatch_num = self.num_microbatches // 2

        forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
        forward_block_num = self.forward_times // forward_block_size

        if self.forward_times >= real_microbatch_num or \
                ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:    # others
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times

        # in the up pipeline, the microbatch_id to consume is 0, 2, 4 (2n)
        # in the down pipeline, the microbatch_id to consume is 1, 3, 5 (2n + 1)
        real_target_microbatch_id = target_microbatch_id * 2
        if pp_rank >= stage_num:
            real_target_microbatch_id += 1
        target_key = UniqueKey(real_target_microbatch_id, target_phase)

        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
        return target_key

    def _initialize_partition(self):
        # to ensure that the down pipeline shares the same parameters as the up
        # pipeline, the partition of a down stage is copied from the
        # corresponding up stage
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        device = self.device
        if pp_rank < stage_num:
            super()._initialize_partition()
        else:
            # down pipeline: create the partition by the original method, then
            # load the state dict of the corresponding up stage
            co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
            # get the corresponding model state dict and wait for its init
            state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
            super()._initialize_partition()
            self.module_partition.load_state_dict(state_dict)

        # init group for chimera in ppg
        ppg.get_chimera_all_reduce_group(pp_rank)

        # lock for step sync
        self.step_sync_lock = threading.Lock()
        self.step_sync_lock.acquire()

        self.have_grad_lock = threading.Lock()
        self.have_grad_lock.acquire()

    def _get_lock_gradient(self):
        self.have_grad_lock.acquire()
        grads = self.get_parameter_gradients()
        self.step_sync_lock.release()
        return grads

    def is_first_stage(self):
        return (self.pp_rank % self.actual_stage_num) == 0

    def is_last_stage(self):
        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1

    def _is_last_step(self, work_item: WorkItem) -> bool:
        if work_item.forward_only:
            last_phase = Phase.FORWARD
        else:
            last_phase = Phase.BACKWARD
        is_last_phase = work_item.phase == last_phase
        last_microbatch_id = self.num_microbatches - 1
        if self.pp_rank < self.actual_stage_num:
            last_microbatch_id -= 1
        is_last_microbatch = work_item.microbatch_id == last_microbatch_id
        return is_last_phase and is_last_microbatch

    def _get_step_order(self) -> List[int]:
        # TODO: if you want to extend this to multi-head chimera, overwrite here
        stage_num = self.actual_stage_num
        pp_rank = self.pp_rank
        # pp_ranks on the same device
        local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
        local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
        return local_device_pp_ranks

    def _hook_before_step(self):
        self.have_grad_lock.release()
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)

        # if the current pp_rank is not the first to step,
        # wait for its previous pp_rank to finish its step
        grads = self.get_parameter_gradients()

        # send
        co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
        co_grads = co_worker.rpc_sync()._get_lock_gradient()
        # sync
        self.step_sync_lock.acquire()
        for i in range(len(grads)):
            grads[i] += co_grads[i]
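Two pieces of bookkeeping above are easy to check by hand: the up pipeline (pp_rank < stage_num) consumes even microbatches 2n while the down pipeline consumes odd microbatches 2n + 1, and each rank merges gradients with its counterpart at co_pp_rank = (pp_rank + stage_num) % (2 * stage_num). The sketch below is illustrative only; it simply re-evaluates those expressions for stage_num = 4.

def chimera_real_microbatch_id(target_microbatch_id: int, pp_rank: int, stage_num: int) -> int:
    # same mapping as ChimeraWorker._get_work_item_key
    real_id = target_microbatch_id * 2
    if pp_rank >= stage_num:    # down pipeline
        real_id += 1
    return real_id

def chimera_co_rank(pp_rank: int, stage_num: int) -> int:
    # same formula as ChimeraWorker._hook_before_step
    return (pp_rank + stage_num) % (2 * stage_num)

stage_num = 4
print([chimera_real_microbatch_id(n, pp_rank=0, stage_num=stage_num) for n in range(3)])    # [0, 2, 4]
print([chimera_real_microbatch_id(n, pp_rank=4, stage_num=stage_num) for n in range(3)])    # [1, 3, 5]
print([chimera_co_rank(r, stage_num) for r in range(2 * stage_num)])                        # [4, 5, 6, 7, 0, 1, 2, 3]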
class ChimeraPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        assert num_microbatches % stage_num == 0, \
            "In Chimera, num_microbatches must be a multiple of stage_num!"
        use_1F1B = False
        chunk = 1

        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
                            output_pp_ranks: List[int], ret_future):
        pass

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
            self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1

    def _create_pp_rank_to_module_partition_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
            self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        num_microbatches = self.num_microbatches
        stage_num = self.stage_num
        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
        # merge up and down
        return {**up_ret_future, **down_ret_future}

    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)

    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_labels(microbatch_id, microlabels)

    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)

    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        stage_num = self.stage_num
        num_microbatches = self.num_microbatches
        if not forward_only:
            for pp_rank in input_pp_ranks:
                up_last_microbatch_id = num_microbatches - 2
                down_last_microbatch_id = num_microbatches - 1

                up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]

                up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
                down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
                up_worker_rref.rpc_sync().get_output_by_key(up_key)
                down_worker_rref.rpc_sync().get_output_by_key(down_key)

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
        """Logic of collecting forward results in Chimera.

        Currently, only single-input, single-output models are supported.
        """
        stage_num = self.stage_num
        forward_result = []
        for pp_rank in output_pp_ranks:
            worker_forward_result = [None] * self.num_microbatches
            for microbatch_id in range(self.num_microbatches):
                offset = (microbatch_id % 2) * stage_num
                ret = ret_future[pp_rank + offset][microbatch_id].wait()
                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                worker_forward_result[microbatch_id] = ret

            worker_forward_result = list(zip(*worker_forward_result))
            forward_result.extend(worker_forward_result)

        return forward_result
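The rank maps built by _create_pp_rank_to_rpc_worker_id and _create_pp_rank_to_module_partition_id are worth seeing concretely: the down pipeline reuses the up pipeline's partitions (same partition ids) but places them on the devices in reverse order (mirrored worker ids), which is what allows ChimeraWorker._initialize_partition to copy weights from the corresponding up stage. The sketch below is illustrative only; it replays the two loops for stage_num = 4.

stage_num = 4
rpc_worker_id = [0] * (stage_num * 2)
module_partition_id = [0] * (stage_num * 2)
for pp_rank in range(stage_num):
    # up pipeline: identity mapping; down pipeline: mirrored workers, same partitions
    rpc_worker_id[pp_rank] = pp_rank
    rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
    module_partition_id[pp_rank] = pp_rank
    module_partition_id[pp_rank + stage_num] = pp_rank

print(rpc_worker_id)          # [0, 1, 2, 3, 3, 2, 1, 0]
print(module_partition_id)    # [0, 1, 2, 3, 0, 1, 2, 3]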