ColossalAI/colossalai/pipeline/rpc/_pipeline_schedule.py
Boyuan Yao 7a58dc5ad2
Update metainfo patch branch (#2517)
* init

* rename and remove useless func

* basic chunk

* add evoformer

* align evoformer

* add meta

* basic chunk

* basic memory

* finish basic inference memory estimation

* finish memory estimation

* fix bug

* finish memory estimation

* add part of index tracer

* finish basic index tracer

* add doc string

* add doc str

* polish code

* polish code

* update active log

* polish code

* add possible region search

* finish region search loop

* finish chunk define

* support new op

* rename index tracer

* finishi codegen on msa

* redesign index tracer, add source and change compute

* pass outproduct mean

* code format

* code format

* work with outerproductmean and msa

* code style

* code style

* code style

* code style

* change threshold

* support check_index_duplicate

* support index dupilictae and update loop

* support output

* update memory estimate

* optimise search

* fix layernorm

* move flow tracer

* refactor flow tracer

* format code

* refactor flow search

* code style

* adapt codegen to prepose node

* code style

* remove abandoned function

* remove flow tracer

* code style

* code style

* reorder nodes

* finish node reorder

* update run

* code style

* add chunk select class

* add chunk select

* code style

* add chunksize in emit, fix bug in reassgin shape

* code style

* turn off print mem

* add evoformer openfold init

* init openfold

* add benchmark

* add print

* code style

* code style

* init openfold

* update openfold

* align openfold

* use max_mem to control stratge

* update source add

* add reorder in mem estimator

* improve reorder efficeincy

* support ones_like, add prompt if fit mode search fail

* fix a bug in ones like, dont gen chunk if dim size is 1

* fix bug again

* update min memory stratege, reduce mem usage by 30%

* last version of benchmark

* refactor structure

* restruct dir

* update test

* rename

* take apart chunk code gen

* close mem and code print

* code format

* rename ambiguous variable

* seperate flow tracer

* seperate input node dim search

* seperate prepose_nodes

* seperate non chunk input

* seperate reorder

* rename

* ad reorder graph

* seperate trace flow

* code style

* code style

* fix typo

* set benchmark

* rename test

* update codegen test

* Fix state_dict key missing issue of the ZeroDDP (#2363)

* Fix state_dict output for ZeroDDP duplicated parameters

* Rewrite state_dict based on get_static_torch_model

* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)

* update codegen test

* update codegen test

* add chunk search test

* code style

* add available

* [hotfix] fix gpt gemini example (#2404)

* [hotfix] fix gpt gemini example

* [example] add new assertions

* remove autochunk_available

* [workflow] added nightly release to pypi (#2403)

* add comments

* code style

* add doc for search chunk

* [doc] updated readme regarding pypi installation (#2406)

* add doc for search

* [doc] updated kernel-related optimisers' docstring (#2385)

* [doc] updated kernel-related optimisers' docstring

* polish doc

* rename trace_index to trace_indice

* rename function from index to indice

* rename

* rename in doc

* [polish] polish code for get_static_torch_model (#2405)

* [gemini] polish code

* [testing] remove code

* [gemini] make more robust

* rename

* rename

* remove useless function

* [worfklow] added coverage test (#2399)

* [worfklow] added coverage test

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* add doc for trace indice

* [docker] updated Dockerfile and release workflow (#2410)

* add doc

* update doc

* add available

* change imports

* add test in import

* [workflow] refactored the example check workflow (#2411)

* [workflow] refactored the example check workflow

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* Update parallel_context.py (#2408)

* [hotfix] add DISTPAN argument for benchmark (#2412)

* change the benchmark config file

* change config

* revert config file

* rename distpan to distplan

* [workflow] added precommit check for code consistency (#2401)

* [workflow] added precommit check for code consistency

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code

* adapt new fx

* [workflow] added translation for non-english comments (#2414)

* [setup] refactored setup.py for dependency graph (#2413)

* change import

* update doc

* [workflow] auto comment if precommit check fails (#2417)

* [hotfix] add norm clearing for the overflow step (#2416)

* [examples] adding tflops to PaLM (#2365)

* [workflow]auto comment with test coverage report (#2419)

* [workflow]auto comment with test coverage report

* polish code

* polish yaml

* [doc] added documentation for CI/CD (#2420)

* [doc] added documentation for CI/CD

* polish markdown

* polish markdown

* polish markdown

* [example] removed duplicated stable diffusion example (#2424)

* [zero] add inference mode and its unit test (#2418)

* [workflow] report test coverage even if below threshold (#2431)

* [example] improved the clarity yof the example readme (#2427)

* [example] improved the clarity yof the example readme

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* polish workflow

* [ddp] add is_ddp_ignored (#2434)

[ddp] rename to is_ddp_ignored

* [workflow] make test coverage report collapsable (#2436)

* [autoparallel] add shard option (#2423)

* [fx] allow native ckpt trace and codegen. (#2438)

* [cli] provided more details if colossalai run fail (#2442)

* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)

* [autoparallel] integrate device mesh initialization into autoparallelize

* add megatron solution

* update gpt autoparallel examples with latest api

* adapt beta value to fit the current computation cost

* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)

* [ddp] add is_ddp_ignored

[ddp] rename to is_ddp_ignored

* [zero] fix state_dict and load_state_dict

* fix bugs

* [zero] update unit test for ZeroDDP

* [example] updated the hybrid parallel tutorial (#2444)

* [example] updated the hybrid parallel tutorial

* polish code

* [zero] add warning for ignored parameters (#2446)

* [example] updated large-batch optimizer tutorial (#2448)

* [example] updated large-batch optimizer tutorial

* polish code

* polish code

* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)

* [workflow] fixed the on-merge condition check (#2452)

* [workflow] automated the compatiblity test (#2453)

* [workflow] automated the compatiblity test

* polish code

* [autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish

* [workflow] automated bdist wheel build (#2459)

* [workflow] automated bdist wheel build

* polish workflow

* polish readme

* polish readme

* Fix False warning in initialize.py (#2456)

* Update initialize.py

* pre-commit run check

* [examples] update autoparallel tutorial demo (#2449)

* [examples] update autoparallel tutorial demo

* add test_ci.sh

* polish

* add conda yaml

* [cli] fixed hostname mismatch error (#2465)

* [example] integrate autoparallel demo with CI (#2466)

* [example] integrate autoparallel demo with CI

* polish code

* polish code

* polish code

* polish code

* [zero] low level optim supports ProcessGroup (#2464)

* [example] update vit ci script (#2469)

* [example] update vit ci script

* [example] update requirements

* [example] update requirements

* [example] integrate seq-parallel tutorial with CI (#2463)

* [zero] polish low level optimizer (#2473)

* polish pp middleware (#2476)

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

* [example] update gpt gemini example ci test (#2477)

* [zero] add unit test for low-level zero init (#2474)

* [workflow] fixed the skip condition of  example weekly check workflow (#2481)

* [example] stable diffusion add roadmap

* add dummy test_ci.sh

* [example] stable diffusion add roadmap (#2482)

* [CI] add test_ci.sh for palm, opt and gpt (#2475)

* polish code

* [example] titans for gpt

* polish readme

* remove license

* polish code

* update readme

* [example] titans for gpt (#2484)

* [autoparallel] support origin activation ckpt on autoprallel system (#2468)

* [autochunk] support evoformer tracer (#2485)

support full evoformer tracer, which is a main module of alphafold. previously we just support a simplifed version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code

* [example] fix requirements (#2488)

* [zero] add unit testings for hybrid parallelism  (#2486)

* [hotfix] gpt example titans bug #2493

* polish code and fix dataloader bugs

* [hotfix] gpt example titans bug #2493 (#2494)

* [fx] allow control of ckpt_codegen init (#2498)

* [fx] allow control of ckpt_codegen init

Currently in ColoGraphModule, ActivationCheckpointCodeGen will be set automatically in __init__. But other codegen can't be set if so. 
So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__.

* code style

* [example] dreambooth example

* add test_ci.sh to dreambooth

* [autochunk] support autochunk on evoformer (#2497)

* Revert "Update parallel_context.py (#2408)"

This reverts commit 7d5640b9db.

* add avg partition (#2483)

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

* [auto-chunk] support extramsa (#3) (#2504)

* [utils] lazy init. (#2148)

* [utils] lazy init.

* [utils] remove description.

* [utils] complete.

* [utils] finalize.

* [utils] fix names.

* [autochunk] support parsing blocks (#2506)

* [zero] add strict ddp mode (#2508)

* [zero] add strict ddp mode

* [polish] add comments for strict ddp mode

* [zero] fix test error

* [doc] update opt and tutorial links (#2509)

* [workflow] fixed changed file detection (#2515)

Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
2023-01-27 09:52:21 +08:00


import threading
from typing import Callable, Dict, List

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future

from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem

# Implementations of different pipeline schedules:
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the user-facing class


class FillDrainWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # fill-drain: run the forward of every microbatch first, then the backwards
        num_microbatches = self.num_microbatches

        if self.forward_times < num_microbatches:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        else:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times

        target_key = UniqueKey(target_microbatch_id, target_phase)

        return target_key


class FillDrainPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use the interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        use_1F1B = False

        super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)


class OneFOneBWorker(WorkerBase):

    def _get_work_item_key(self) -> UniqueKey:
        # choose forward or backward from the current number of outstanding microbatches
        pp_rank = self.pp_rank
        actual_stage_num = self.actual_stage_num
        num_microbatches = self.num_microbatches
        is_last_stage = pp_rank == actual_stage_num - 1

        if self.outstanding <= self.outstanding_range[0]:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        elif self.outstanding >= self.outstanding_range[1]:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:
            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

        target_key = UniqueKey(target_microbatch_id, target_phase)

        # change outstanding_range when:
        # 1. forward times reach actual_stage_num: this is the end of the continuous forward
        # 2. forward times reach num_microbatches: this is the end of 1F1B mode
        if not is_last_stage and target_key.phase == Phase.FORWARD:
            if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
                # Why num_microbatches > 2? Because there is no steady phase when num_microbatches <= 2
                outstanding_min = actual_stage_num - pp_rank - 1
                outstanding_max = actual_stage_num - pp_rank
                self.outstanding_range = (outstanding_min, outstanding_max)
            if target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

        return target_key


class OneFOneBPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use the interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        use_1F1B = True

        super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)


class ChimeraWorker(WorkerBase):

    def _get_producer_consumer(self) -> None:
        rank = self.pp_rank
        min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
        max_pp_rank = min_pp_rank + self.actual_stage_num - 1

        assert self.producer_stage_ids is None, f"all the producers of rank {rank} have been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} have been subscribed"

        # should be arranged in order: the order of the inputs of the current forward
        self.producer_stage_ids = []
        self.consumer_stage_ids = []

        # Just for demo
        prev_rank = rank - 1
        next_rank = rank + 1
        if prev_rank >= min_pp_rank:
            self.producer_stage_ids.append(prev_rank)
        if next_rank <= max_pp_rank:
            self.consumer_stage_ids.append(next_rank)

    def _get_work_item_key(self) -> UniqueKey:
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        real_microbatch_num = self.num_microbatches // 2

        forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
        forward_block_num = self.forward_times // forward_block_size

        if self.forward_times >= real_microbatch_num or \
                ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:    # others
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times

        # In the up pipeline, the microbatch_ids to consume are 0, 2, 4 (2n)
        # In the down pipeline, the microbatch_ids to consume are 1, 3, 5 (2n + 1)
        real_target_microbatch_id = target_microbatch_id * 2
        if pp_rank >= stage_num:
            real_target_microbatch_id += 1
        target_key = UniqueKey(real_target_microbatch_id, target_phase)

        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
        return target_key

    def _initialize_partition(self):
        # To ensure the down pipeline shares the same parameters with the up pipeline,
        # the partition of each down stage is copied from the corresponding up stage
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        device = self.device
        if pp_rank < stage_num:
            super()._initialize_partition()
        else:
            # down pipeline: create the partition with the original method,
            # then load the state dict of the corresponding up stage
            co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
            # get the corresponding model state dict and wait for its init
            state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
            super()._initialize_partition()
            self.module_partition.load_state_dict(state_dict)

        # init group for chimera in ppg
        ppg.get_chimera_all_reduce_group(pp_rank)

        # lock for step sync
        self.step_sync_lock = threading.Lock()
        self.step_sync_lock.acquire()

        self.have_grad_lock = threading.Lock()
        self.have_grad_lock.acquire()

    def _get_lock_gradient(self):
        self.have_grad_lock.acquire()
        grads = self.get_parameter_gradients()
        self.step_sync_lock.release()
        return grads

    def is_first_stage(self):
        return (self.pp_rank % self.actual_stage_num) == 0

    def is_last_stage(self):
        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1

    def _is_last_step(self, work_item: WorkItem) -> bool:
        if work_item.forward_only:
            last_phase = Phase.FORWARD
        else:
            last_phase = Phase.BACKWARD
        is_last_phase = work_item.phase == last_phase
        last_microbatch_id = self.num_microbatches - 1
        if self.pp_rank < self.actual_stage_num:
            last_microbatch_id -= 1
        is_last_microbatch = work_item.microbatch_id == last_microbatch_id
        return is_last_phase and is_last_microbatch

    def _get_step_order(self) -> List[int]:
        # TODO: to extend this to multi-head Chimera, override this method
        stage_num = self.actual_stage_num
        pp_rank = self.pp_rank
        # pp_ranks placed on the same device
        local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
        local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
        return local_device_pp_ranks

    def _hook_before_step(self):
        self.have_grad_lock.release()
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)

        # if the current pp_rank is not the first to step,
        # wait for its previous pp_rank to finish stepping
        grads = self.get_parameter_gradients()

        # fetch the gradients of the co-rank that holds the same partition
        co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
        co_grads = co_worker.rpc_sync()._get_lock_gradient()
        # wait until the co-rank has fetched our gradients, then accumulate in place
        self.step_sync_lock.acquire()
        for i in range(len(grads)):
            grads[i] += co_grads[i]


class ChimeraPipelineEngine(PipelineEngineBase):

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:

        assert num_microbatches % stage_num == 0, \
            "In Chimera, num_microbatches must be a multiple of stage_num!"
        use_1F1B = False
        chunk = 1

        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
                            output_pp_ranks: List[int], ret_future):
        pass

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
            self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1

    def _create_pp_rank_to_module_partition_id(self) -> None:
        stage_num = self.stage_num
        self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
            self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        num_microbatches = self.num_microbatches
        stage_num = self.stage_num
        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
        # merge up and down
        return {**up_ret_future, **down_ret_future}

    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)

    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
        # offset is 0 for all the ranks in the up pipeline
        # offset is stage_num for all the ranks in the down pipeline
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_labels(microbatch_id, microlabels)

    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)

    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        stage_num = self.stage_num
        num_microbatches = self.num_microbatches
        if not forward_only:
            for pp_rank in input_pp_ranks:
                up_last_microbatch_id = num_microbatches - 2
                down_last_microbatch_id = num_microbatches - 1

                up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]

                up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
                down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
                up_worker_rref.rpc_sync().get_output_by_key(up_key)
                down_worker_rref.rpc_sync().get_output_by_key(down_key)

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Collect the forward results in Chimera.

        Currently, only models with one input and one output are supported.
        """
        stage_num = self.stage_num
        forward_result = []
        for pp_rank in output_pp_ranks:
            worker_forward_result = [None] * self.num_microbatches
            for microbatch_id in range(self.num_microbatches):
                offset = (microbatch_id % 2) * stage_num
                ret = ret_future[pp_rank + offset][microbatch_id].wait()
                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                worker_forward_result[microbatch_id] = ret

            worker_forward_result = list(zip(*worker_forward_result))
            forward_result.extend(worker_forward_result)

        return forward_result
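
`FillDrainWorker._get_work_item_key` compares only two counters, so the order in which one stage requests work can be replayed without any RPC. The sketch below is a standalone illustration and not part of ColossalAI. `Phase` is a local stand-in enum and `fill_drain_order` is a hypothetical helper. It also assumes every requested work item is already available, whereas the real worker waits for items to arrive from neighbouring stages.

```python
from enum import Enum
from typing import List, Tuple


class Phase(Enum):
    # Local stand-in for the engine's Phase enum (assumption, not the real class).
    FORWARD = 0
    BACKWARD = 1


def fill_drain_order(num_microbatches: int) -> List[Tuple[Phase, int]]:
    """Replay FillDrainWorker's choice for one stage, assuming every item is ready."""
    forward_times = 0
    backward_times = 0
    order: List[Tuple[Phase, int]] = []
    while backward_times < num_microbatches:
        # Same test as _get_work_item_key: all forwards first, backwards afterwards.
        if forward_times < num_microbatches:
            order.append((Phase.FORWARD, forward_times))
            forward_times += 1
        else:
            order.append((Phase.BACKWARD, backward_times))
            backward_times += 1
    return order


if __name__ == "__main__":
    for phase, microbatch_id in fill_drain_order(3):
        print(phase.name, microbatch_id)
```

Every stage therefore holds the activations of all `num_microbatches` forwards before its first backward. This is the memory cost that the 1F1B worker below is designed to bound.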
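
The 1F1B worker switches between forward and backward by comparing `self.outstanding` with `self.outstanding_range`, and it narrows that range as forwards are scheduled. The two hypothetical helpers below restate that rule as pure functions so it can be checked in isolation. They do not model how the base class initialises `outstanding_range` or updates `outstanding`, so any starting range passed in is only a placeholder.

```python
from typing import Tuple

Range = Tuple[int, int]


def choose_phase(outstanding: int, outstanding_range: Range) -> str:
    """Phase test from OneFOneBWorker._get_work_item_key."""
    if outstanding <= outstanding_range[0]:
        return "forward"
    if outstanding >= outstanding_range[1]:
        return "backward"
    raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")


def next_range(pp_rank: int, actual_stage_num: int, num_microbatches: int,
               microbatch_id: int, outstanding_range: Range) -> Range:
    """Range after scheduling the forward of `microbatch_id` (hypothetical helper)."""
    if pp_rank == actual_stage_num - 1:
        # The last stage keeps whatever range it already has.
        return outstanding_range
    if microbatch_id == num_microbatches - 1:
        # Last forward on this stage: only backwards remain (drain).
        # Checked first because in the worker this assignment overrides the other one.
        return (0, 0)
    if microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
        # End of the continuous forward: enter the steady 1F1B phase.
        return (actual_stage_num - pp_rank - 1, actual_stage_num - pp_rank)
    return outstanding_range


if __name__ == "__main__":
    # (4, 4) is an arbitrary placeholder for the current range.
    print(next_range(0, 4, 8, 3, (4, 4)))
```

Under this rule, once a non-last stage `pp_rank` enters the steady phase, it runs a forward only while `outstanding <= actual_stage_num - pp_rank - 1`. Assuming `outstanding` counts forwards minus backwards, this caps that stage at `actual_stage_num - pp_rank` outstanding microbatches.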
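
Chimera runs two pipelines over the same workers. pp_ranks `0..stage_num-1` form the up pipeline, and pp_ranks `stage_num..2*stage_num-1` form the down pipeline, which visits the workers in reverse order. The helpers below are hypothetical restatements, written without RPC so the mappings can be inspected. They cover `_create_pp_rank_to_rpc_worker_id`, `_create_pp_rank_to_module_partition_id`, the `offset` used in `_set_input`, `_set_labels` and `_subscribe_forward`, and the id remapping in `ChimeraWorker._get_work_item_key`.

```python
from typing import List


def rpc_worker_ids(stage_num: int) -> List[int]:
    """pp_rank -> RPC worker id, as in _create_pp_rank_to_rpc_worker_id."""
    ids = [0] * (stage_num * 2)
    for pp_rank in range(stage_num):
        ids[pp_rank] = pp_rank                              # up pipeline: stage i on worker i
        ids[pp_rank + stage_num] = stage_num - pp_rank - 1  # down pipeline: reversed workers
    return ids


def partition_ids(stage_num: int) -> List[int]:
    """pp_rank -> module partition id, as in _create_pp_rank_to_module_partition_id."""
    ids = [0] * (stage_num * 2)
    for pp_rank in range(stage_num):
        ids[pp_rank] = pp_rank
        ids[pp_rank + stage_num] = pp_rank
    return ids


def route(microbatch_id: int, pp_rank: int, stage_num: int) -> int:
    """pp_rank that receives a microbatch: even ids go up, odd ids go down."""
    return pp_rank + (microbatch_id % 2) * stage_num


def real_microbatch_id(local_id: int, pp_rank: int, stage_num: int) -> int:
    """Global microbatch id consumed by a worker's local_id-th item of one phase."""
    return local_id * 2 + (1 if pp_rank >= stage_num else 0)


if __name__ == "__main__":
    print(rpc_worker_ids(4))  # [0, 1, 2, 3, 3, 2, 1, 0]
    print(partition_ids(4))   # [0, 1, 2, 3, 0, 1, 2, 3]
```

For `stage_num = 4`, worker 0 hosts pp_rank 0 (the first up stage) and pp_rank 7 (the last down stage). Every down pp_rank shares its module partition with the up pp_rank `stage_num` below it, which is the rank that `_initialize_partition` copies its state dict from.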
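
Before stepping, each Chimera worker exchanges gradients with `co_pp_rank`, the rank that holds the same module partition in the other pipeline direction. `_get_step_order` gives the two pp_ranks on one device a common stepping order. The sketch restates both rules as hypothetical pure functions and replaces the in-place tensor accumulation `grads[i] += co_grads[i]` with a sum over lists of floats. It omits the two locks that order the RPC calls.

```python
from typing import List


def step_order(pp_rank: int, stage_num: int) -> List[int]:
    """The two pp_ranks on pp_rank's device, in stepping order (as in _get_step_order)."""
    local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
    local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
    return local_device_pp_ranks


def co_pp_rank(pp_rank: int, stage_num: int) -> int:
    """Rank holding the same module partition in the other pipeline (as in _hook_before_step)."""
    return (pp_rank + stage_num) % (2 * stage_num)


def accumulate(grads: List[float], co_grads: List[float]) -> List[float]:
    """Element-wise sum; the real hook adds co_grads into grads in place."""
    return [g + c for g, c in zip(grads, co_grads)]


if __name__ == "__main__":
    stage_num = 4
    for rank in range(2 * stage_num):
        print(rank, step_order(rank, stage_num), co_pp_rank(rank, stage_num))
```

For `stage_num = 4`, pp_ranks 1 and 5 share partition 1 but run on workers 1 and 2, so their gradients are exchanged over RPC through `_get_lock_gradient`. Both ranks then step with the same summed gradients.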