ColossalAI/colossalai/pipeline/rpc/_pipeline_base.py
Boyuan Yao 7a58dc5ad2
Update metainfo patch branch (#2517)

Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
2023-01-27 09:52:21 +08:00

1309 lines
58 KiB
Python

import inspect
import math
import threading
from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.distributed.rpc as rpc
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future

from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (
    get_batch_lengths,
    pyobj_map,
    pytree_filter,
    pytree_map,
    split_batch,
    tensor_shape_list,
    type_detail,
)


class Phase(Enum):
    FORWARD = 0
    BACKWARD = 1
    UPDATE = 2
    INPUT = 3


class UniqueKey:
    __slots__ = ('microbatch_id', 'phase')
    microbatch_id: int
    phase: Phase

    def __init__(self, microbatch_id, phase) -> None:
        self.microbatch_id = microbatch_id
        self.phase = phase

    def __eq__(self, __o: object) -> bool:
        return (self.microbatch_id == __o.microbatch_id) and (self.phase == __o.phase)

    def __hash__(self) -> int:
        return tuple.__hash__((self.microbatch_id, self.phase))

    def __repr__(self) -> str:
        return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})'


class WorkItem:
    __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
                 'num_microbatches', 'forward_only')

    stage_id: int
    phase: Phase
    args: Tuple[Any]
    kwargs: Dict[str, Any]
    output: Future
    microbatch_id: int
    refcount: int
    batch_id: int
    num_microbatches: int
    forward_only: bool

    def __init__(self,
                 stage_id,
                 phase,
                 args,
                 kwargs,
                 output,
                 microbatch_id,
                 batch_id,
                 num_microbatches,
                 forward_only,
                 refcount=0) -> None:
        # every slot name matches a constructor argument, so copy them over by name
        for attr_name in self.__slots__:
            setattr(self, attr_name, locals()[attr_name])


class BackwardCache:
    __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
    checkpoint: bool
    stage_input_args: Tuple[Any]
    stage_input_kwargs: Dict[Any, Any]
    stage_outputs: Tuple[Any]

    def __init__(self,
                 stage_input_args: Tuple[Any],
                 stage_input_kwargs: Dict[Any, Any] = None,
                 stage_outputs: Tuple[Any] = None,
                 checkpoint: bool = False) -> None:
        # every slot name matches a constructor argument, so copy them over by name
        for arg_name in self.__slots__:
            setattr(self, arg_name, locals()[arg_name])


class WorkerBase(ABC):

    def __init__(self,
                 partition_fn: Callable,
                 partition_args: tuple,
                 pp_rank: int,
                 actual_stage_num: int,
                 num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:
        super().__init__()

        self.pp_rank = pp_rank
        self.actual_stage_num = actual_stage_num
        self.num_microbatches = num_microbatches
        self.checkpoint = checkpoint

        if data_process_func is not None:
            self.data_process_func = partial(data_process_func, pp_rank)

        self.device = device
        self._initialize_outstanding_range()

        # variables and constants for context management
        self.outstanding = 0
        self.forward_times = 0
        self.backward_times = 0
        self.reset_key = UniqueKey(0, Phase.FORWARD)

        # rref of other workers
        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None

        # locks for the work and output lists
        self._initialize_lock()

        # topology info
        self.producer_stage_ids: List[int] = None
        self.consumer_stage_ids: List[int] = None

        # module partitions
        self.partition_fn = partition_fn
        self.partition_args = partition_args
        self.criterion = criterion
        self.metric = metric
        self.reset = False

        # context to maintain loop
        self._initialize_context_container()

        # main loop
        self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
        self.main_loop_thread.start()

    def _get_future_by_device(self):
        return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])

    def _initialize_outstanding_range(self):
        outstanding_range = None
        if self.pp_rank == self.actual_stage_num - 1:
            outstanding_range = (0, 1)
        else:
            outstanding_range = (self.actual_stage_num, self.actual_stage_num)
        self.outstanding_range = outstanding_range

    def _initialize_context_container(self):
        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
        self.microbatch_id_to_labels: Dict[int, Any] = dict()
        self.work_list: Dict[UniqueKey, WorkItem] = dict()
        self.output_list: Dict[UniqueKey, WorkItem] = dict()

    def _initialize_lock(self):
        self.partition_condition_lock = threading.Condition(threading.Lock())
        self.work_list_condition_lock = threading.Condition(threading.Lock())
        self.output_list_condition_lock = threading.Condition(threading.Lock())
        self.label_lock = threading.Condition(threading.Lock())
        self.reset_condition = threading.Condition(threading.Lock())

    def _initialize_partition(self):
        partition_fn = self.partition_fn
        partition_args = self.partition_args
        device = self.device
        with self.partition_condition_lock:
            self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
            self.partition_condition_lock.notify_all()

    def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
        with self.output_list_condition_lock:
            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
            output_work_item = self.output_list[key]
            output = output_work_item.output
            if not ref_use and output_work_item.phase != Phase.INPUT:
                self.output_list.pop(key)

        if not ref_use and output_work_item.phase != Phase.INPUT:
            output_work_item.refcount += 1
            refcount = output_work_item.refcount
            # lifecycle management for DAG scheduler
            if output_work_item.phase == Phase.FORWARD:
                lifecycle = len(self.get_consumer_stage_ids())
                if self.is_model_output():  # an extra reference for scheduler collecting results
                    lifecycle += 1
            elif output_work_item.phase == Phase.BACKWARD:
                lifecycle = len(self.get_producer_stage_ids())
                if self.is_model_input() and self._is_last_step(
                        output_work_item):  # an extra reference for ensure_backward
                    lifecycle += 1
            else:
                lifecycle = 0
                refcount = 0

            with self.output_list_condition_lock:
                if refcount <= lifecycle:
                    self.output_list[key] = output_work_item
                    self.output_list_condition_lock.notify_all()

        if isinstance(output, Future):
            output = output.wait()

        return output

    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref

        # Some schedules (like Chimera) need the other workers' info to initialise the partition,
        # so the partition is constructed after pp_rank_to_worker_rref has been registered.
        self._initialize_partition()

    # ref_use controls the lifecycle counter:
    # if ref_use is True, the lifecycle counter is not incremented.
    # offsets allows fetching only part of the output to reduce communication costs.
    def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
        output = self._get_output_all(key, ref_use, rank)
        if offsets is None:  # get all for non iterable output
            return output
        else:  # get part for iterable output
            output = [output[i] for i in offsets]
        return output

    def get_numels(self) -> int:
        numel = sum(param.numel() for param in self.module_partition.parameters())
        return numel

    def get_parameters(self) -> List[torch.Tensor]:
        return [p for p in self.module_partition.parameters()]

    def get_parameter_gradients(self) -> List[torch.Tensor]:
        return [p.grad for p in self.module_partition.parameters()]

    def get_partition(self):
        with self.partition_condition_lock:
            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
            return self.module_partition

    def get_partition_state_dict(self):
        with self.partition_condition_lock:
            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
            return self.module_partition.state_dict()

    def _make_args_kwargs(self, microbatch, merge=False):
        if isinstance(microbatch, dict):
            if merge:
                return list(microbatch.values()), {}
            return [], microbatch
        elif isinstance(microbatch, torch.Tensor):
            return [microbatch], {}
        elif isinstance(microbatch, (tuple, list)):
            args = []
            kwargs = {}
            for arg in microbatch:
                if isinstance(arg, dict):
                    kwargs.update(arg)
                else:
                    args.append(arg)
            if merge:
                arg_lst = args
                for arg in kwargs.values():
                    arg_lst.append(arg)
                return arg_lst, {}
            return args, kwargs
        else:
            raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")

    # just for first pp_rank
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        output = self._get_future_by_device()

        if not self.use_middleware():
            # make args and kwargs
            args, kwargs = self._make_args_kwargs(microbatch)

            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
                                 self.num_microbatches, forward_only)
            with self.work_list_condition_lock:
                self.work_list[key] = work_item
                self.work_list_condition_lock.notify_all()
        else:
            # make args and kwargs
            arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)

            # the first stage assigns the correct inputs to the other stages
            topo: Topo = self.get_topo()
            self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
            input_partition = topo.get_input_partition()
            self_input_offsets = input_partition.get_output_offsets(self_partition_id)
            recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)

            # set input for self rank
            self_arg_lst = []
            for off in self_input_offsets:
                self_arg_lst.append(arg_lst[off])

            work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
                                 self.num_microbatches, forward_only)
            with self.work_list_condition_lock:
                self.work_list[key] = work_item
                self.work_list_condition_lock.notify_all()

            # put the input tensors that other stages need into output_list as Phase.INPUT
            work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
                                        self.num_microbatches, forward_only)

            with self.output_list_condition_lock:
                self.output_list[recv_input_key] = work_item_remote
                self.output_list_condition_lock.notify_all()

    # just for last pp_rank
    def set_labels(self, microbatch_id: int, microlabels: Any):
        with self.label_lock:
            self.microbatch_id_to_labels[microbatch_id] = microlabels
            self.label_lock.notify_all()

    # just for last pp_rank
    def _begin_backward(self, microbatch_id: int):
        with self.work_list_condition_lock:
            assert self.producer_stage_ids is not None

            key = UniqueKey(microbatch_id, Phase.BACKWARD)
            output = self._get_future_by_device()
            grad_wrt_loss = None

            work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                 self.num_microbatches, False)

            self.work_list[key] = work_item
            self.work_list_condition_lock.notify_all()

    def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
        """
        You should call this function asynchronously
        """
        stage_id = self.pp_rank
        output = self._get_future_by_device()
        if not self.use_middleware():
            producer_num = len(self.producer_stage_ids)
            subscribe_forward_futures: List[Future] = [None] * producer_num
            for i in range(producer_num):
                producer_stage_id = self.producer_stage_ids[i]
                producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
        else:
            producer_stage_ids = self.get_producer_stage_ids()
            producer_num = len(producer_stage_ids)
            if self.need_model_input():
                producer_num += 1  # for input partition
            subscribe_forward_futures: List[Future] = [None] * producer_num

            # TODO(jiangziyue) get single value instead of the whole output
            if self.need_model_input():
                producer_stage_id = 0
                producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                offsets = self._get_input_offsets_by_index(target_index=0)
                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
                                                                                                  rank=self.pp_rank,
                                                                                                  offsets=offsets)

                for i in range(0, producer_num - 1):
                    producer_stage_id = producer_stage_ids[i]
                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                    target_index = i + 1
                    offsets = self._get_input_offsets_by_index(target_index=target_index)
                    if offsets is not None and len(offsets) == 0:  # no need to do rpc
                        subscribe_forward_futures[target_index] = []
                    else:
                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
                            producer_output_key, rank=self.pp_rank, offsets=offsets)

            else:
                for i in range(producer_num):
                    producer_stage_id = producer_stage_ids[i]
                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                    target_index = i
                    offsets = self._get_input_offsets_by_index(target_index=target_index)
                    if offsets is not None and len(offsets) == 0:  # no need to do rpc
                        subscribe_forward_futures[target_index] = []
                    else:
                        subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
                            producer_output_key, rank=self.pp_rank, offsets=offsets)

        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, forward_only)

        return work_item_from_producer

    # TODO(jiangziyue) Profile the side effect of the lock for lifecycle protection and consider a better one.
    def subscribe_producer(self, microbatch_id: int, forward_only: bool):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        with self.work_list_condition_lock:
            if key not in self.work_list:
                # In the current PP middleware design for DAGs, get_output_by_key (called by
                # _subscribe_producer) may be executed only once per producer-consumer stage pair,
                # because each call counts towards the lifecycle of the work_item. Calling
                # _subscribe_producer under the same lock as the work_list update guarantees
                # that the lifecycle counter stays consistent.
                work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only)
                self.work_list[key] = work_item_from_producer
                self.work_list_condition_lock.notify_all()

    def _subscribe_consumer(self, microbatch_id: int):
        """
        You should call this function asynchronously
        """
        stage_id = self.pp_rank
        output = self._get_future_by_device()
        if not self.use_middleware():
            consumer_stage_ids = self.consumer_stage_ids
        else:
            consumer_stage_ids = self.get_consumer_stage_ids()
        consumer_num = len(consumer_stage_ids)
        subscribe_backward_futures: List[Future] = [None] * consumer_num
        for i in range(consumer_num):
            consumer_stage_id = consumer_stage_ids[i]
            consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
            consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
            target_index = i
            offsets = self._get_output_offsets_by_index(target_index=target_index)
            if offsets is not None and len(offsets) == 0:  # no need to do rpc
                subscribe_backward_futures[target_index] = []
            else:
                subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
                    consumer_output_key, rank=self.pp_rank, offsets=offsets)

        # flatten args
        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
                                           microbatch_id, None, self.num_microbatches, False)

        return work_item_from_consumer

    def subscribe_consumer(self, microbatch_id: int):
        key = UniqueKey(microbatch_id, Phase.BACKWARD)
        with self.work_list_condition_lock:
            if key not in self.work_list:
                # In the current PP middleware design for DAGs, get_output_by_key (called by
                # _subscribe_consumer) may be executed only once per producer-consumer stage pair,
                # because each call counts towards the lifecycle of the work_item. Calling
                # _subscribe_consumer under the same lock as the work_list update guarantees
                # that the lifecycle counter stays consistent.
                work_item_from_consumer = self._subscribe_consumer(microbatch_id)
                self.work_list[key] = work_item_from_consumer
                self.work_list_condition_lock.notify_all()

    def get_producer_stage_ids(self):
        producer_stage_ids = []
        rank = self.pp_rank
        if not self.use_middleware():
            prev_rank = rank - 1
            if prev_rank >= 0:
                producer_stage_ids.append(prev_rank)
        else:
            topo: Topo = self.get_topo()
            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
            input_partition_ids = self_partition.get_input_partition_ids()
            model_input_partition_id = topo.get_input_partition_id()
            for partition_id in input_partition_ids:
                # skip the model input partition in the current implementation;
                # it is handled separately.
                if partition_id != model_input_partition_id:
                    producer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
        return producer_stage_ids

    def get_consumer_stage_ids(self):
        consumer_stage_ids = []
        rank = self.pp_rank
        if not self.use_middleware():
            next_rank = rank + 1
            if next_rank <= self.actual_stage_num - 1:
                consumer_stage_ids.append(next_rank)
        else:
            topo: Topo = self.get_topo()
            self_partition_id = self.pp_rank_to_partition_id(rank, topo)
            self_partition: Partition = topo.get_partition_by_id(self_partition_id)
            output_partition_ids = self_partition.get_output_partition_ids()
            model_output_partition_id = topo.get_output_partition_id()
            for partition_id in output_partition_ids:
                if model_output_partition_id != partition_id:
                    consumer_stage_ids.append(self.partition_id_to_pp_rank(partition_id, topo))
        return consumer_stage_ids

    def _get_producer_consumer(self) -> None:
        rank = self.pp_rank
        assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

        # should be arranged in order, matching the input order of the current forward
        self.producer_stage_ids = self.get_producer_stage_ids()
        self.consumer_stage_ids = self.get_consumer_stage_ids()

    def pp_rank_to_partition_id(self, pp_rank: int, topo: Topo):
        partition_ids = topo.get_mid_partition_ids()
        return partition_ids[pp_rank]

    def partition_id_to_pp_rank(self, partition_id: int, topo: Topo):
        partition_ids = topo.get_mid_partition_ids()
        for i, id in enumerate(partition_ids):
            if id == partition_id:
                return i

    def get_topo(self):
        with self.partition_condition_lock:
            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
            if hasattr(self.module_partition, '_topo'):
                return self.module_partition._topo
            else:
                return None

    def use_middleware(self):
        topo = self.get_topo()
        return topo is not None

    def _get_input_offsets_by_index(self, target_index):
        res = []
        topo: Topo = self.get_topo()
        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
        model_input_partition_id = topo.get_input_partition_id()
        input_vals = self_partition.get_input_vals()
        producer_stage_ids = self.get_producer_stage_ids()

        if self.need_model_input():
            # 0 for data from input batch
            # >= 1 for data from prev stages
            base = 1
        else:
            # data from prev stages
            base = 0
        for val in input_vals:
            val_pos = val.get()
            src_partition_id = val_pos.partition_id
            src_offset = val_pos.offset
            src_index = base
            src_partition = topo.get_partition_by_id(src_partition_id)
            output_len = len(src_partition.get_output_vals())
            # data from not-input partition
            if src_partition_id != model_input_partition_id:
                src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
                src_index = base
                for i, stage_id in enumerate(producer_stage_ids):
                    if stage_id == src_stage_id:
                        src_index += i
                        break
            else:  # data from input partition
                src_index = 0
            # when output_len = 1, not iterable
            if target_index == src_index:
                if output_len == 1:
                    res = None  # offset = None to get all outputs
                    return res
                else:
                    res.append(src_offset)
        return res

    def _get_output_offsets_by_index(self, target_index):
        res = []
        topo: Topo = self.get_topo()
        self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
        self_partition: Partition = topo.get_partition_by_id(self_partition_id)
        output_vals = self_partition.get_output_vals()
        consumer_stage_ids = self.get_consumer_stage_ids()
        for val_list in output_vals:
            # An output may be passed to several downstream stages.
            target = None
            for val_pos in val_list.get():
                dst_partition_id = val_pos.partition_id
                dst_offset = val_pos.offset
                dst_partition = topo.get_partition_by_id(dst_partition_id)
                input_len = len(dst_partition.get_input_vals())
                dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
                for i, stage_id in enumerate(consumer_stage_ids):
                    if stage_id == dst_stage_id:
                        dst_index = i
                        break
                if target_index == dst_index:
                    if input_len == 1:
                        res = None  # offset = None to get all outputs
                        return res
                    else:
                        res.append(dst_offset)
        return res
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_fwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
else: # get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
model_input_partition_id = topo.get_input_partition_id()
input_vals = self_partition.get_input_vals()
producer_stage_ids = self.get_producer_stage_ids()
if self.need_model_input():
# 0 for data from input batch
# >= 1 for data from prev stages
base = 1
else:
# data from prev stages
base = 0
for val in input_vals:
val_pos = val.get()
src_partition_id = val_pos.partition_id
src_offset = val_pos.offset
src_index = base
src_partition = topo.get_partition_by_id(src_partition_id)
output_len = len(src_partition.get_output_vals())
# data from a non-input partition
if src_partition_id != model_input_partition_id:
src_stage_id = self.partition_id_to_pp_rank(src_partition_id, topo)
src_index = base
for i, stage_id in enumerate(producer_stage_ids):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
src_index = 0
# when output_len == 1, the source output is a single value, not an iterable
if output_len == 1:
target = args_or_kwargs[src_index]
else:
offsets = self._get_input_offsets_by_index(src_index)
real_offset = offsets.index(src_offset)
target = args_or_kwargs[src_index][real_offset]
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs_bwd(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
for i, arg in enumerate(args_or_kwargs):
args_or_kwargs[i] = arg.wait()
if args_or_kwargs is not None: # get by offset
flatten_args = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
output_vals = self_partition.get_output_vals()
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to multiple downstream stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
dst_partition = topo.get_partition_by_id(dst_partition_id)
input_len = len(dst_partition.get_input_vals())
dst_stage_id = self.partition_id_to_pp_rank(dst_partition_id, topo)
for i, stage_id in enumerate(consumer_stage_ids):
if stage_id == dst_stage_id:
dst_index = i
break
if input_len == 1:
part_grad = args_or_kwargs[dst_index]
else:
offsets = self._get_output_offsets_by_index(dst_index)
real_offsets = offsets.index(dst_offset)
part_grad = args_or_kwargs[dst_index][real_offsets]
if target is None:
target = part_grad
elif part_grad is not None:
target += part_grad
else:
continue
flatten_args.append(target)
args_or_kwargs = flatten_args
return args_or_kwargs
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
"""
This method controls the order in which microbatches are consumed.
"""
def is_first_stage(self):
return self.pp_rank == 0
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def need_model_input(self):
need_input = False
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition = topo.get_partition_by_id(self_partition_id)
partition_inputs = self_partition.get_input_partition_ids()
model_input_partition_id = topo.get_input_partition_id()
if model_input_partition_id in partition_inputs:
need_input = True
return not self.is_first_stage() and need_input
def is_model_output(self):
return self.is_last_stage()
def is_model_input(self):
return self.is_first_stage()
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
args = args_kwargs[0]
kwargs = args_kwargs[1]
else:
args = args_kwargs
kwargs = {}
return args, kwargs
def _consume_work_item_by_phase(self, work_item: WorkItem):
phase = work_item.phase
args = work_item.args
kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
consume_result = None
is_first_stage = self.is_first_stage()
is_last_stage = self.is_last_stage()
if phase == Phase.FORWARD:
# notify this stage's consumers to fetch data before forward
if not is_last_stage:
for stage_id in self.consumer_stage_ids:
consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
# maintain pipeline context
self.forward_times += 1
if not forward_only:
self.outstanding += 1
# parse and integrate args and kwargs
if is_first_stage:
args = self._get_real_args_kwargs_fwd(args)
kwargs = self._get_real_args_kwargs_fwd(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = self._get_real_args_kwargs_fwd(args)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs)
stage_outputs = None
stage_input_args = args
stage_input_kwargs = kwargs
use_checkpoint = None
if forward_only:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
labels = self.microbatch_id_to_labels.pop(microbatch_id)
loss: torch.Tensor = self.criterion(consume_result, labels)
if self.metric is not None:
metric_result = self.metric(consume_result, labels)
if isinstance(metric_result, torch.Tensor):
metric_result = metric_result.item()
else:
metric_result = None
consume_result = [loss.item(), metric_result]
# the last stage doesn't need a checkpoint, because it runs backward immediately
stage_input_args = None
stage_input_kwargs = None
stage_outputs = consume_result
elif self.checkpoint and not is_last_stage:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = consume_result
use_checkpoint = True
else:
consume_result = self.module_partition(*args, **kwargs)
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
labels = self.microbatch_id_to_labels.pop(microbatch_id)
loss: torch.Tensor = self.criterion(consume_result, labels)
if self.metric is not None:
metric_result = self.metric(consume_result, labels)
if isinstance(metric_result, torch.Tensor):
metric_result = metric_result.item()
else:
metric_result = None
consume_result = [loss.item(), metric_result]
else:
loss = consume_result
stage_outputs = loss
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatically
self._begin_backward(microbatch_id)
elif phase == Phase.BACKWARD:
# notify this stage's producers to fetch data before backward
if not is_first_stage:
for stage_id in self.producer_stage_ids:
producer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
producer_worker_rref.remote().subscribe_consumer(microbatch_id)
self.backward_times += 1
self.outstanding -= 1
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
stage_outputs = backward_cache.stage_outputs
stage_input_args = backward_cache.stage_input_args
stage_input_kwargs = backward_cache.stage_input_kwargs
use_checkpoint = backward_cache.checkpoint
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = self._get_real_args_kwargs_bwd(args)
else:
grad_tensors = None
# keep tensors only (only tensors can take part in backward)
# TODO(jiangziyue) : Are all values that need backpropagation torch.Tensor?
stage_outputs = pytree_filter(lambda x: True, stage_outputs, process_types=torch.Tensor)
grad_tensors = pytree_filter(lambda x: True, grad_tensors, process_types=torch.Tensor)
# output the grad of every input to the producer, even if it has no grad (output None),
# to keep the offsets aligned with the topo's record.
if grad_tensors is not None:
filtered_outputs = []
filtered_grads = []
for i, grad in enumerate(grad_tensors):
stage_output = stage_outputs[i]
if stage_output.requires_grad and grad is not None:
filtered_outputs.append(stage_output)
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
consume_result = []
if not is_first_stage:
# In the current design, the input must be a flattened list of args.
for arg in stage_input_args:
if isinstance(arg, torch.Tensor):
consume_result.append(arg.grad)
else:
consume_result.append(None)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
return consume_result
def _get_store_len(self):
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
def _get_parameter_grad_sum(self):
grad_sum = 0
for p in self.module_partition.parameters():
if p.grad is not None:
grad_sum += p.grad.sum()
return grad_sum
def _is_first_step(self, work_item: WorkItem) -> bool:
return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
def _is_last_step(self, work_item: WorkItem) -> bool:
if work_item.forward_only:
last_phase = Phase.FORWARD
else:
last_phase = Phase.BACKWARD
is_last_phase = work_item.phase == last_phase
is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
return is_last_phase and is_last_microbatch
def _hook_before_step(self):
pass
# block the main loop until the next batch input arrives
def _wait_for_reset(self):
with self.reset_condition:
self.reset_condition.wait_for(lambda: self.reset)
self.reset = False
# do the main loop to consume ready_list
def _work_loop(self):
# for init
self._get_producer_consumer()
torch.cuda.set_device(ppg.get_local_pp_rank())
# main loop
while True:
work_item_key = self._get_work_item_key()
# move current work item to output_list to activate subscribe in advance
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: work_item_key in self.work_list)
work_item = self.work_list[work_item_key]
with self.output_list_condition_lock:
# assert work_item_key not in self.output_list
self.output_list[work_item_key] = work_item
self.output_list_condition_lock.notify_all()
consume_result = self._consume_work_item_by_phase(work_item)
with self.work_list_condition_lock:
self.work_list.pop(work_item_key)
work_item.output.set_result(consume_result)
# if this is the last step of a batch, wait for the context reset (and optimizer step)
if self._is_last_step(work_item):
self._wait_for_reset()
# reset context and resume loop
def reset_context(self):
self.forward_times = 0
self.backward_times = 0
self.outstanding = 0
self._initialize_outstanding_range()
with self.work_list_condition_lock:
self.work_list.clear()
with self.output_list_condition_lock:
self.output_list.clear()
with self.reset_condition:
self.reset = True
self.reset_condition.notify_all()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
def step(self):
self._hook_before_step()
self.optimizer.step()
self.optimizer.zero_grad()
class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
worker_type,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
super().__init__()
self.worker_type = worker_type
self.partition_fn: Callable = partition_fn
self.chunk = chunk
self.criterion = criterion
self.metric = metric
self.num_microbatches = num_microbatches
self.device = device
self.use_1F1B = use_1F1B
self.stage_num = stage_num
self.checkpoint = checkpoint
self.data_process_func = data_process_func
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._create_pp_rank_to_module_partition_id()
self._init_worker()
def _check_argument(self) -> None:
# compute the virtual stage number
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must not exceed the device count!"
# check data_process_func
data_process_func = self.data_process_func
if data_process_func is not None:
assert callable(data_process_func), "data_process_func must be a function"
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
sig = inspect.signature(data_process_func)
assert len(
sig.parameters
) == 2, f"data_process_func must take exactly 2 arguments, but received {len(sig.parameters)} instead"
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
def _create_pp_rank_to_rpc_worker_id(self) -> None:
"""Create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is split into 4 parts with stage_num = 2 and chunk = 2, then
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], meaning the first and third partitions
are placed on device 0 and the others on device 1.
"""
stage_num = self.stage_num
actual_stage_num = self._get_actual_stage_num()
self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num
for pp_rank in range(actual_stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
def _create_pp_rank_to_module_partition_id(self) -> None:
"""By default (both fill-drain and 1F1B), the number of model partitions equals
actual_stage_num, so each model partition is allocated to its corresponding stage.
"""
actual_stage_num = self._get_actual_stage_num()
self.pp_rank_to_module_partition_id = [0] * actual_stage_num
for pp_rank in range(actual_stage_num):
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
def _init_worker(self) -> None:
actual_stage_num = self._get_actual_stage_num()
worker_type = self.worker_type
checkpoint = self.checkpoint
num_microbatches = self.num_microbatches
device = self.device
criterion = self.criterion
metric = self.metric
partition_fn = self.partition_fn
chunk = self.chunk
data_process_func = self.data_process_func
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
partition_args = (partition_id, chunk, actual_stage_num)
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
worker_type,
args=(partition_fn, partition_args, pp_rank,
actual_stage_num, num_microbatches, device,
criterion, metric, checkpoint, data_process_func))
# let each worker know the global worker rrefs (including its own)
sync_futs = []
for pp_rank in self.pp_rank_to_worker_rref:
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
sync_futs.append(fut)
for fut in sync_futs:
fut.wait()
def remote_numels(self) -> Dict[int, int]:
numels = {}
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[stage_id]
numel = worker_rref.rpc_sync().get_numels()
numels[stage_id] = numel
return numels
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
parameters = {}
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
parameters[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for p in worker_rref.rpc_sync().get_parameters():
parameters[stage_id].append(p)
return parameters
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
grads = {}
actual_stage_num = self._get_actual_stage_num()
for stage_id in range(actual_stage_num):
grads[stage_id] = []
worker_rref = self.pp_rank_to_worker_rref[stage_id]
for grad in worker_rref.rpc_sync().get_parameter_gradients():
grads[stage_id].append(grad)
return grads
def get_input_pp_ranks(self) -> List[int]:
return [0]
def get_output_pp_ranks(self) -> List[int]:
return [self._get_actual_stage_num() - 1]
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
actual_stage_num = self._get_actual_stage_num()
use_1F1B = self.use_1F1B
if microbatch_id >= actual_stage_num:
if forward_only or not use_1F1B:
for pp_rank in output_pp_ranks:
ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
else:
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
futs = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().get_output_by_key(key, ref_use=True, offsets=[])
futs.append(fut)
for fut in futs:
fut.wait()
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
# TODO : add relationship between input_pp_ranks and parts of microbatch
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
# TODO : add relationship between output_pp_ranks and parts of microlabels
worker_rref.remote().set_labels(microbatch_id, microlabels)
# TODO(jiangziyue) : get model output with single value, instead of merging into last stage.
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
key = UniqueKey(microbatch_id, Phase.FORWARD)
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
if not forward_only:
backward_result = []
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
fut = worker_rref.rpc_async().get_output_by_key(
key, offsets=[]) # only ensure the res exists, no need for real data.
backward_result.append(fut)
for fut in backward_result:
fut.wait()
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
forward_result = []
for pp_rank in output_pp_ranks:
worker_forward_result = [None] * self.num_microbatches
for microbatch_id in range(self.num_microbatches):
ret = ret_future[pp_rank][microbatch_id].wait()
# TODO : more stable format
ret = [ret] if isinstance(ret, torch.Tensor) else ret
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))
forward_result.extend(worker_forward_result)
return forward_result
def _reset_worker(self):
actual_stage_num = self._get_actual_stage_num()
reset_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().reset_context()
reset_futs.append(fut)
for fut in reset_futs:
fut.wait()
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
batch_lengths = get_batch_lengths(batch)
batch_length = batch_lengths[0]
if labels is not None and not forward_only:
assert hasattr(
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
num_microbatches = self.num_microbatches
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
microbatch_size = math.ceil(batch_length / num_microbatches)
device = self.device
# If Chimera mode is used, the ranks of the down pipeline are excluded from 'input_pp_ranks' or 'output_pp_ranks'
input_pp_ranks = self.get_input_pp_ranks()
output_pp_ranks = self.get_output_pp_ranks()
# a cache to collect data and control flow
ret_future = self._create_ret_future(output_pp_ranks)
for microbatch_id in range(num_microbatches):
# throttle data input
# to avoid exceeding wait limits
# self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
batch_start = microbatch_size * microbatch_id
batch_end = min(batch_start + microbatch_size, batch_length)
# set input
microbatch = split_batch(batch, batch_start, batch_end, device)
self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
# set labels
if labels is not None:
# microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microlabels = split_batch(labels, batch_start, batch_end, device)
self._set_labels(output_pp_ranks, microbatch_id, microlabels)
# get data asynchronously
self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future)
# wait for first rank to ensure all backwards are done
self._ensure_backward(forward_only, input_pp_ranks)
# collect forward result
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and hasattr(self, 'optimizer_class'):
self.step()
self._reset_worker() # reset worker attributes for next batch
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer_class = optimizer_class
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
def step(self):
actual_stage_num = self._get_actual_stage_num()
step_futs: List[Future] = []
for pp_rank in range(actual_stage_num):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
fut = worker_rref.rpc_async().step()
step_futs.append(fut)
for fut in step_futs:
fut.wait()