[pipeline/chimera] test chimera | fix bug of initializing (#1615)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing
Kirigaya Kazuto 2022-09-20 18:00:39 +08:00 committed by GitHub
parent 504ff1d101
commit 170fa81095
13 changed files with 342 additions and 144 deletions

View File

@@ -110,6 +110,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
                name_list.append((name, param))
        for name, param in name_list:
+           if hasattr(module, name):
                delattr(module, name)
            setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

View File

@@ -1,5 +1,6 @@
from typing import List, Dict, Tuple
import os
+import threading

from torch.distributed import rpc
import torch.distributed as dist
@@ -10,13 +11,17 @@ from colossalai.tensor import ProcessGroup
class PipelineProcessGroup:
    # TODO : flexible API for DP size and TP size
    # In the future design mode, dp_degree and tp_degree should be removed
-   def __init__(self,
+   def __init__(self) -> None:
+       self.is_initialize = False
+
+   def set_global_info(self,
                rank: int,
                world_size: int,
                dp_degree: int = 1,
                tp_degree: int = 1,
                num_worker_threads: int = 1,
                device: str = "cuda") -> None:
        device_mesh_size = dp_degree * tp_degree
        assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
        self._num_worker_threads = num_worker_threads
@@ -42,6 +47,11 @@ class PipelineProcessGroup:
        self._is_first_pp_rank = self._pp_rank == 0
        self._is_last_pp_rank = self._pp_rank == self._stage_num - 1

+       self.is_initialize = True
+
+       # lock
+       self.chimera_lock = threading.Lock()
+
    def _initialize_process_group(self):
        stage_num = self.get_stage_num()
        if stage_num == 1:
@@ -133,3 +143,25 @@ class PipelineProcessGroup:
    def get_tp_global_ranks(self):
        pass

+   def get_chimera_all_reduce_group(self, pp_rank: int):
+       with self.chimera_lock:
+           if not hasattr(self, 'chimera_groups'):
+               world_size = self.get_world_size()
+               stage_num = self.get_stage_num()
+               assert world_size % 2 == 0, 'world_size must be even in chimera!'
+               self.chimera_groups = {}
+               for rank in range(world_size // 2):
+                   pair = [rank, world_size - 1 - rank]
+                   group = dist.new_group(pair)
+                   self.chimera_groups[pair[0]] = group
+                   self.chimera_groups[pair[1]] = group
+                   self.chimera_groups[pair[0] + stage_num] = group
+                   self.chimera_groups[pair[1] + stage_num] = group
+
+               self.chimera_step_lock = threading.Lock()
+               self.chimera_step_lock.acquire()
+
+       return self.chimera_groups[pp_rank]
+
+
+ppg = PipelineProcessGroup()
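For reference, the pairing above assumes the usual Chimera layout in which the device count equals stage_num, pp_ranks 0 .. stage_num - 1 form the up pipeline and stage_num .. 2 * stage_num - 1 the down pipeline, so an up stage and the down stage it mirrors resolve to the same all-reduce group. A minimal standalone sketch of that key layout (plain lists stand in for dist.new_group handles; world_size = stage_num = 4 is an assumed example):

    # Sketch of the chimera_groups mapping above, without torch.distributed.
    world_size = 4                 # assumed: world_size == stage_num
    stage_num = world_size
    chimera_groups = {}
    for rank in range(world_size // 2):
        pair = [rank, world_size - 1 - rank]    # the two devices holding replicas of one stage
        for key in (pair[0], pair[1], pair[0] + stage_num, pair[1] + stage_num):
            chimera_groups[key] = pair          # up and down pp_ranks share one group

    print(chimera_groups)
    # {0: [0, 3], 3: [0, 3], 4: [0, 3], 7: [0, 3], 1: [1, 2], 2: [1, 2], 5: [1, 2], 6: [1, 2]}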

View File

@@ -0,0 +1,3 @@
+from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
+
+__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']

View File

@@ -139,7 +139,8 @@ class BackwardCache:
class WorkerBase(ABC):

    def __init__(self,
-                module_partition: nn.Module,
+                partition_fn: Callable,
+                partition_args: tuple,
                 pp_rank: int,
                 actual_stage_num: int,
                 num_microbatches: int,
@@ -165,21 +166,22 @@ class WorkerBase(ABC):
        # rref of other workers
        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None

+       # lock for the list
+       self._initialize_lock()
+
        # topology info
        self.producer_stage_ids: List[int] = None
        self.consumer_stage_ids: List[int] = None

        # module partitions
-       self.module_partition = module_partition.to(device)
+       self.partition_fn = partition_fn
+       self.partition_args = partition_args
        self.criterion = criterion
        self.metric = metric

        # context to maintain loop
        self._initialize_context_container()

-       # lock for the list
-       self._initialize_lock()
-
        # main loop
        self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
        self.main_loop_thread.start()
@@ -202,20 +204,37 @@ class WorkerBase(ABC):
        self.output_list: Dict[UniqueKey, WorkItem] = dict()

    def _initialize_lock(self):
+       self.partition_condition_lock = threading.Condition(threading.Lock())
        self.work_list_condition_lock = threading.Condition(threading.Lock())
        self.output_list_condition_lock = threading.Condition(threading.Lock())
        self.label_lock = threading.Condition(threading.Lock())

+   def _initialize_partition(self):
+       partition_fn = self.partition_fn
+       partition_args = self.partition_args
+       device = self.device
+       with self.partition_condition_lock:
+           self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
+           self.partition_condition_lock.notify_all()
+
    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
        self.pp_rank_to_worker_rref = pp_rank_to_worker_rref

+       # for some schedule need the other worker's info to initialise partition (like Chimera)
+       # construction of partition is executed after the registion of pp_rank_to_worker_rref
+       self._initialize_partition()
+
    def get_output_by_key(self, key: UniqueKey) -> Any:
        with self.output_list_condition_lock:
            self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
            output_work_item = self.output_list[key]
-       output = output_work_item.output.wait()
+       output = output_work_item.output
+       if isinstance(output, Future):
+           output = output.wait()
        # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
        output_work_item.refcount += 1
@@ -231,6 +250,16 @@ class WorkerBase(ABC):
    def get_parameter_gradients(self) -> List[torch.Tensor]:
        return [p.grad for p in self.module_partition.parameters()]

+   def get_partition(self):
+       with self.partition_condition_lock:
+           self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
+           return self.module_partition
+
+   def get_partition_state_dict(self):
+       with self.partition_condition_lock:
+           self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
+           return self.module_partition.state_dict()
+
    # just for first pp_rank
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
        assert self.consumer_stage_ids is not None
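With this change a worker no longer receives a ready-built nn.Module; it is handed a partition_fn plus partition_args and materialises its partition lazily under partition_condition_lock, so get_partition and get_partition_state_dict can block until the module exists. A minimal sketch of a function that satisfies the (pp_rank, chunk, stage_num) calling convention used when the engine spawns workers; the layer shapes are made up for illustration:

    import torch
    import torch.nn as nn

    # Hypothetical partition_fn: executed on the worker side as
    # partition_fn(*partition_args) and expected to return an nn.Module.
    def example_partition(pp_rank: int, chunk: int, stage_num: int) -> nn.Module:
        torch.manual_seed(1024)      # same seed everywhere, so replicas start identical
        out_dim = 1 if pp_rank == stage_num - 1 else 100
        return nn.Linear(100, out_dim)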
@@ -520,6 +549,15 @@ class WorkerBase(ABC):
        is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
        return is_last_phase and is_last_microbatch

+   def _hook_before_step(self):
+       pass
+
+   def _reset_context(self):
+       self.forward_times = 0
+       self.backward_times = 0
+       self.outstanding = 0
+       self._initialize_outstanding_range()
+
    # do the main loop to consume ready_list
    def _work_loop(self):
        # for init

@@ -545,19 +583,17 @@ class WorkerBase(ABC):
            consume_result = self._consume_work_item_by_phase(work_item)

            color_debug(
-               f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
+               f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
                'work loop', 'green')

            work_item.output.set_result(consume_result)

            # if is last step in one batch reset context and do step
            if self._is_last_step(work_item):
+               self._hook_before_step()
                if hasattr(self, 'optimizer') and not work_item.forward_only:
                    self.step()
-               self.forward_times = 0
-               self.backward_times = 0
-               self.outstanding = 0
-               self._initialize_outstanding_range()
+               self._reset_context()

    def initialize_optimizer(self, optimizer_class: type, **kwargs):
        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
@@ -577,7 +613,7 @@ class PipelineEngineBase(ABC, nn.Module):
    def __init__(self,
                 worker_type,
-                module_partitions,
+                partition_fn: Callable,
                 stage_num,
                 num_microbatches,
                 device: str,

@@ -588,7 +624,7 @@ class PipelineEngineBase(ABC, nn.Module):
                 checkpoint: bool = False) -> None:
        super().__init__()
        self.worker_type = worker_type
-       self.module_partitions: List[nn.Module] = module_partitions
+       self.partition_fn: Callable = partition_fn
        self.chunk = chunk
        self.criterion = criterion
        self.metric = metric
@@ -609,18 +645,15 @@ class PipelineEngineBase(ABC, nn.Module):
    def _check_argument(self) -> None:
        self.virtual_stage_num = self.stage_num * self.chunk
        assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
-       assert self.virtual_stage_num == len(
-           self.module_partitions), "stage_num * chunk must be equal to length of model partition!"

    def _get_actual_stage_num(self) -> int:
        return self.stage_num if self.chunk == 1 else self.virtual_stage_num

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """create a map from model partition to stage_id, which is useful when use_interleave is True.

-       e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
-       stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
+       e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
+       pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
        of partitions will be moved to device 0 and the others to device 1
        """
        stage_num = self.stage_num
@@ -647,26 +680,34 @@ class PipelineEngineBase(ABC, nn.Module):
        device = self.device
        criterion = self.criterion
        metric = self.metric
+       partition_fn = self.partition_fn
+       chunk = self.chunk

        for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
-           module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
+           partition_id = self.pp_rank_to_module_partition_id[pp_rank]
+           partition_args = (partition_id, chunk, actual_stage_num)
            rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
            if device[:4] == 'cuda':
                device = f'cuda:{rpc_worker_id}'
-           module_partition = self.module_partitions[module_partition_id]
            self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
                                                              worker_type,
-                                                             args=(module_partition, pp_rank, actual_stage_num,
-                                                                   num_microbatches, device, criterion, metric,
-                                                                   checkpoint))
+                                                             args=(partition_fn, partition_args, pp_rank,
+                                                                   actual_stage_num, num_microbatches, device,
+                                                                   criterion, metric, checkpoint))

        # let each worker know global worker rref (include itself)
+       sync_futs = []
        for pp_rank in self.pp_rank_to_worker_rref:
-           self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
+           fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
+           sync_futs.append(fut)
+
+       for fut in sync_futs:
+           fut.wait()
    def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
        parameters = {}
-       for stage_id in self.pp_rank_to_worker_rref:
+       actual_stage_num = self._get_actual_stage_num()
+       for stage_id in range(actual_stage_num):
            parameters[stage_id] = []
            worker_rref = self.pp_rank_to_worker_rref[stage_id]
            for p in worker_rref.rpc_sync().get_parameters():

@@ -675,7 +716,8 @@ class PipelineEngineBase(ABC, nn.Module):
    def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
        grads = {}
-       for stage_id in self.pp_rank_to_worker_rref:
+       actual_stage_num = self._get_actual_stage_num()
+       for stage_id in range(actual_stage_num):
            grads[stage_id] = []
            worker_rref = self.pp_rank_to_worker_rref[stage_id]
            for grad in worker_rref.rpc_sync().get_parameter_gradients():
@@ -784,7 +826,7 @@ class PipelineEngineBase(ABC, nn.Module):
        # collect forward result
        forward_result = self._collect_forward_result(output_pp_ranks, ret_future)

-       if not forward_only and labels is not None:
+       if not forward_only and hasattr(self, 'optimizer_class'):
            # wait for all step
            for pp_rank in self.pp_rank_to_worker_rref:
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]

@@ -793,9 +835,8 @@ class PipelineEngineBase(ABC, nn.Module):
        return forward_result

    def initialize_optimizer(self, optimizer_class: type, **kwargs):
-       actual_stage_num = self._get_actual_stage_num()
        self.optimizer_class = optimizer_class
-       for pp_rank in range(actual_stage_num):
+       for pp_rank in self.pp_rank_to_worker_rref:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
            worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
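Taken together, the engine-side changes mean user code passes a partition function instead of a list of pre-built modules, and the remote optimizer step is triggered only when initialize_optimizer has registered an optimizer_class. A hedged usage sketch (class and argument names as they appear in this diff, to be run from the master process once the RPC workers are up, as in the tests; the sizes are illustrative):

    import torch
    from colossalai.pipeline.rpc import OneFOneBPipelineEngine

    # example_partition(pp_rank, chunk, stage_num) -> nn.Module, as sketched earlier
    engine = OneFOneBPipelineEngine(partition_fn=example_partition,
                                    stage_num=2,
                                    num_microbatches=4,
                                    device='cuda',
                                    chunk=1,
                                    checkpoint=False)
    engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    data = torch.randn((1024, 100), device='cuda')
    forward_result = engine.forward_backward(data)    # workers step because optimizer_class is set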

View File

@@ -1,10 +1,12 @@
from typing import List, Callable, Dict

-import torch.nn as nn
+import torch
+import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef

-from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
+from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
+from colossalai.pipeline.pipeline_process_group import ppg

# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
@@ -35,7 +37,7 @@ class FillDrainWorker(WorkerBase):
class FillDrainPipelineEngine(PipelineEngineBase):

    def __init__(self,
-                module_partitions: List[nn.Module],
+                partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,

@@ -49,8 +51,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
            "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
        use_1F1B = False

-       super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
-                        criterion, metric, checkpoint)
+       super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
+                        metric, checkpoint)
class OneFOneBWorker(WorkerBase):

@@ -94,7 +96,7 @@ class OneFOneBWorker(WorkerBase):
class OneFOneBPipelineEngine(PipelineEngineBase):

    def __init__(self,
-                module_partitions: List[nn.Module],
+                partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,

@@ -106,10 +108,11 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
        if chunk > 1:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
+       assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
        use_1F1B = True

-       super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
-                        criterion, metric, checkpoint)
+       super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
+                        metric, checkpoint)
class ChimeraWorker(WorkerBase):

@@ -139,21 +142,16 @@ class ChimeraWorker(WorkerBase):
        stage_num = self.actual_stage_num
        real_microbatch_num = self.num_microbatches // 2

-       if self.forward_times < real_microbatch_num:
-           if (pp_rank + 1) % stage_num == 0:    # last rank
-               forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
-               if forward_blocks > self.backward_times:
-                   target_phase = Phase.BACKWARD
-                   target_microbatch_id = self.backward_times
-               else:
-                   target_phase = Phase.FORWARD
-                   target_microbatch_id = self.forward_times
-           else:    # others
-               target_phase = Phase.FORWARD
-               target_microbatch_id = self.forward_times
-       else:
-           target_phase = Phase.BACKWARD
-           target_microbatch_id = self.backward_times
+       forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
+       forward_block_num = self.forward_times // forward_block_size
+
+       if self.forward_times >= real_microbatch_num or \
+               ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
+           target_phase = Phase.BACKWARD
+           target_microbatch_id = self.backward_times
+       else:    # others
+           target_phase = Phase.FORWARD
+           target_microbatch_id = self.forward_times

        # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
        # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
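The rewritten key selection folds the old nested branches into a single predicate: a worker switches to BACKWARD once it has run half of the microbatches forward (one pipeline copy's share), or earlier on a pipeline's last stage as soon as a full forward block is ahead of the backward count. A standalone trace of that predicate (pure Python; the stage and microbatch numbers are an assumed example):

    def next_phase(pp_rank, stage_num, num_microbatches, forward_times, backward_times):
        # mirrors the decision logic in ChimeraWorker._get_work_item_key above
        real_microbatch_num = num_microbatches // 2
        forward_block_size = 1 if num_microbatches < stage_num else num_microbatches // stage_num
        forward_block_num = forward_times // forward_block_size
        if forward_times >= real_microbatch_num or \
                ((pp_rank + 1) % stage_num == 0 and forward_block_num > backward_times):
            return 'BACKWARD', backward_times
        return 'FORWARD', forward_times

    # last stage (pp_rank 3) of a 4-stage pipeline, 8 microbatches in total
    for f, b in [(0, 0), (2, 0), (2, 1), (4, 1)]:
        print((f, b), next_phase(3, 4, 8, f, b))
    # -> FORWARD 0, then BACKWARD 0, then FORWARD 2, then BACKWARD 1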
@@ -164,22 +162,85 @@ class ChimeraWorker(WorkerBase):
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key

+   def _initialize_partition(self):
+       # In order to ensure the down pipeline share the same parameter
+       # with the up pipeline, partition of down partition will be copied
+       # from corresponding up stage
+       pp_rank = self.pp_rank
+       stage_num = self.actual_stage_num
+       device = self.device
+       if pp_rank < stage_num:
+           super()._initialize_partition()
+       else:
+           # if it is down pipeline, create partition by origin method
+           co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
+           # get the coresponding model state dict and wait for its init
+           state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
+           super()._initialize_partition()
+           self.module_partition.load_state_dict(state_dict)
+
+       # init group for chimera in ppg
+       ppg.get_chimera_all_reduce_group(pp_rank)
+
    def is_first_stage(self):
        return (self.pp_rank % self.actual_stage_num) == 0

    def is_last_stage(self):
        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1

+   def _is_last_step(self, work_item: WorkItem) -> bool:
+       if work_item.forward_only:
+           last_phase = Phase.FORWARD
+       else:
+           last_phase = Phase.BACKWARD
+       is_last_phase = work_item.phase == last_phase
+       last_microbatch_id = self.num_microbatches - 1
+       if self.pp_rank < self.actual_stage_num:
+           last_microbatch_id -= 1
+       is_last_microbatch = work_item.microbatch_id == last_microbatch_id
+       return is_last_phase and is_last_microbatch
+
+   def _get_step_order(self) -> List[int]:
+       # TODO : If you want to extend it to multi head chimera, overwrite here
+       stage_num = self.actual_stage_num
+       pp_rank = self.pp_rank
+       # pp_rank in the same device
+       local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
+       local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
+       return local_device_pp_ranks
+
+   def _hook_before_step(self):
+       pp_rank = self.pp_rank
+       orders = self._get_step_order()
+       step_index = orders.index(pp_rank)
+
+       # if currrent pp_rank is not the first to do step
+       # wait its previous pp_rank finish step
+       all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
+       grads = self.get_parameter_gradients()
+
+       # print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
+       if step_index == 1:
+           ppg.chimera_step_lock.acquire()
+
+       print(f'rank_{self.pp_rank} before all reduce')
+       dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
+       print(f'rank_{self.pp_rank} after all reduce')
+
+       if step_index == 0:
+           ppg.chimera_step_lock.release()
class ChimeraPipelineEngine(PipelineEngineBase):

    def __init__(self,
-                module_partitions,
-                stage_num,
-                num_microbatches,
+                partition_fn: Callable,
+                stage_num: int,
+                num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
@@ -189,11 +250,12 @@ class ChimeraPipelineEngine(PipelineEngineBase):
            "In Chimera, num_microbatches must be the multiply of stage_num!"
        use_1F1B = False
        chunk = 1
-       super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
-                        criterion, metric, checkpoint)
+
+       super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
+                        metric, checkpoint)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
-                           input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
+                           input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
        pass
    def _create_pp_rank_to_rpc_worker_id(self) -> None:

@@ -254,7 +316,6 @@ class ChimeraPipelineEngine(PipelineEngineBase):
        up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
        down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)

        up_worker_rref.rpc_sync().get_output_by_key(up_key)
-
        down_worker_rref.rpc_sync().get_output_by_key(down_key)
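A note on the step ordering used by ChimeraWorker above: _get_step_order decides which of the two pp_ranks sharing a device runs its gradient all-reduce and optimizer step first, and chimera_step_lock (created already acquired in get_chimera_all_reduce_group) makes the second one wait until the first releases it. A standalone trace of the ordering rule for the stage_num = 2 case used in the tests:

    def step_order(pp_rank, stage_num):
        # mirrors ChimeraWorker._get_step_order above
        ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
        ranks.sort(reverse=min(ranks) < stage_num // 2)
        return ranks

    for pp_rank in range(4):
        print(pp_rank, step_order(pp_rank, stage_num=2))
    # 0 [3, 0]   device 0: pp_rank 3 steps first, pp_rank 0 waits on chimera_step_lock
    # 1 [1, 2]   device 1: pp_rank 1 steps first, pp_rank 2 waits
    # 2 [1, 2]
    # 3 [3, 0]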

Binary file not shown.

View File

@@ -8,8 +8,13 @@ import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
+import torch.distributed as dist
from colorama import Back, Style

+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.logging import disable_existing_loggers
+from colossalai import launch
+
rpc_is_initialized = _is_current_rpc_agent_set
@@ -25,12 +30,15 @@ class RpcTestModel(nn.Module):
        self.rank = stage_id
        self.is_last_rank = stage_id == actual_stage_num - 1
        self.linear_name = f'linear_{stage_id}'
+
        if stage_id == 0:
-           setattr(self, self.linear_name, nn.Linear(feat_num, h))
+           linear = nn.Linear(feat_num, h)
        elif stage_id == actual_stage_num - 1:
-           setattr(self, self.linear_name, nn.Linear(h, 1))
+           linear = nn.Linear(h, 1)
        else:
-           setattr(self, self.linear_name, nn.Linear(h, h))
+           linear = nn.Linear(h, h)
+
+       setattr(self, self.linear_name, linear)

    def forward(self, x) -> torch.Tensor:
        linear: nn.Module = getattr(self, self.linear_name)
@@ -46,6 +54,8 @@ def parse_args():
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=16)
+   parser.add_argument('--dp_degree', type=int, default=1)
+   parser.add_argument('--tp_degree', type=int, default=1)
    parser.add_argument('--num_microbatches', type=int, default=2)
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--use_checkpoint', action='store_true')
@@ -74,16 +84,24 @@ def run_worker(rank, args, master_func):
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

-   # config rpc
-   # if cuda is used, set_device_map is a must is configured
-   # for cuda is not supported in torch rpc by default
-   options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
+   device = args.device
    world_size = args.world_size
-   for rank_idx in range(world_size):
-       options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
-
-   rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
+   dp_degree = args.dp_degree
+   tp_degree = args.tp_degree
+   num_worker_threads = args.num_worker_threads
+   host = args.master_addr
+   port = args.master_port
+   backend = 'nccl' if device == 'cuda' else 'gloo'
+
+   disable_existing_loggers()
+   launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
+   ppg.set_global_info(rank=rank,
+                       world_size=world_size,
+                       dp_degree=dp_degree,
+                       tp_degree=tp_degree,
+                       num_worker_threads=num_worker_threads,
+                       device=device)

    # in rpc mode, only rank 0 is needed to be coded
    if rank == 0:
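run_worker now brings each process up through colossalai.launch and then fills in the ppg singleton via set_global_info instead of configuring TensorPipe device maps by hand. The tests drive it through rpc_run, whose implementation is not part of this diff; a hedged sketch of what such a launcher amounts to, assuming one process per rank via torch.multiprocessing:

    import torch.multiprocessing as mp

    from rpc_test_utils import run_worker, parse_args

    def rpc_run_sketch(args, master_func):
        # illustrative only: spawn one process per rank, each executing run_worker above
        mp.spawn(run_worker, args=(args, master_func), nprocs=args.world_size)

    # usage from a test script, where run_master(args) is defined:
    #     rpc_run_sketch(parse_args(), run_master)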

View File

@@ -1,9 +1,21 @@
import torch
from torch import nn
+import torch.autograd as autograd

-from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
+from colossalai.pipeline.rpc import ChimeraPipelineEngine
+from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

+# global variable for model created
+feat_num = 100
+h = 100
+
+
+def partition(pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
+    return partition
+
def run_master(args):
    torch.manual_seed(100)

@@ -17,23 +29,51 @@ def run_master(args):
    use_checkpoint = False
    sample_num = 1024
-   feat_num = 10
-   h = 10
    batch_size = 1024

    assert sample_num % batch_size == 0

-   module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
-
-   engine = ChimeraPipelineEngine(module_partitions=module_partitions,
+   engine = ChimeraPipelineEngine(partition_fn=partition,
                                   stage_num=stage_num,
                                   num_microbatches=num_microbatches,
                                   device=device,
                                   checkpoint=use_checkpoint)
+   engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    input_sample = torch.randn((sample_num, feat_num), device=device)

-   for _ in range(epoch):
-       _ = engine.forward_backward(input_sample, forward_only=False)
+   forward_result = engine.forward_backward(input_sample)
+
+   cuda_rpc_result = []
+   single_result = []
+   actual_stage_num = engine._get_actual_stage_num()
+
+   # compute forward result and backward grad of parameters in cuda rpc
+   cuda_rpc_result.append(sum(forward_result[0]))
+   grad = engine.remote_grad()
+   for stage_id in range(actual_stage_num):
+       for p in grad[stage_id]:
+           cuda_rpc_result.append(p)
+
+   # compute forward result and backward grad of parameters just in rank_0
+   test_model = nn.Sequential(
+       *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
+   # input_sample = input_sample[len(input_sample) // 2:]
+   input_sample = input_sample.requires_grad_()
+   out_val = test_model(input_sample).sum()
+   autograd.backward(out_val)
+   single_result.append(out_val)
+   for p in test_model.parameters():
+       single_result.append(p.grad)
+
+   # print("my")
+   # print(cuda_rpc_result[1])
+   # print("answer:")
+   # print(single_result[1])
+
+   # assert len(cuda_rpc_result) == len(single_result)
+   # for r_c, r_s in zip(cuda_rpc_result, single_result):
+   #     assert_close(r_c, r_s, 0.001, 0.001)

if __name__ == "__main__":

View File

@@ -7,6 +7,16 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

+# global variable for model created
+feat_num = 100
+h = 100
+
+
+def partition(pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
+    return partition
+
def run_master(args):
    torch.manual_seed(100)
@@ -20,20 +30,14 @@ def run_master(args):
    optimizer_class = globals()[args.optimizer]
    lr = 1e-3

    sample_num = 1024
-   feat_num = 100
-   h = 100
    batch_size = 1024

    assert sample_num % batch_size == 0
-   batch_num = sample_num // batch_size

    input_sample = torch.randn((sample_num, feat_num), device=device)

-   module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
-
-   engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+   engine = OneFOneBPipelineEngine(partition_fn=partition,
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,

@@ -55,7 +59,8 @@ def run_master(args):
        cuda_rpc_result.append(p)

    # compute forward result and backward grad of parameters just in rank_0
-   test_model = nn.Sequential(*module_partitions).to(device)
+   test_model = nn.Sequential(
+       *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
    optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
    input_sample = input_sample.requires_grad_()
    out_val = test_model(input_sample).sum()

View File

@@ -18,17 +18,30 @@ from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
+from colossalai.pipeline.pipeline_process_group import ppg


def flatten(x):
    return torch.flatten(x, 1)


-class Flatten(nn.Module):
-
-   def forward(self, x):
-       return torch.flatten(x, start_dim=1)
+def partition(pp_rank: int, chunk: int, stage_num: int):
+   pipelinable = PipelinableContext()
+
+   # build model partitions
+   with pipelinable:
+       # input : [B, 3, 32, 32]
+       _ = resnet50()
+
+   pipelinable.policy = "customized"
+
+   exec_seq = [
+       'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
+   ]
+   pipelinable.to_layer_list(exec_seq)
+   partition = pipelinable.partition(chunk, stage_num, pp_rank)
+   return partition


def run_master(args):
@@ -39,37 +52,12 @@ def run_master(args):
    stage_num = world_size
    num_microbatches = args.num_microbatches

-   assert chunk == 1
-
-   pipelinable = PipelinableContext()
-
-   # build model partitions
-   with pipelinable:
-       # input : [B, 3, 32, 32]
-       model = resnet50()
-
-   exec_seq = [
-       'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
-   ]
-   pipelinable.to_layer_list(exec_seq)
-   module_partitions: List[PipelinableModel] = [
-       pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
-   ]

    # build dataloader
    root = os.environ.get('DATA', './data')
    train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
    criterion = nn.CrossEntropyLoss()

-   partition_1 = module_partitions[0]
-   partition_2 = []
-   for model in module_partitions[1]._module_list:
-       partition_2.append(model)
-   partition_2.insert(len(partition_2) - 1, Flatten())
-   partition_2 = nn.Sequential(*partition_2)
-   module_partitions = [partition_1, partition_2]
-
-   pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+   pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
                                       stage_num=stage_num,
                                       num_microbatches=num_microbatches,
                                       device=device,

View File

@@ -4,6 +4,16 @@ from torch import nn
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

+# global variable for model created
+feat_num = 100
+h = 100
+
+
+def partition(pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
+    return partition
+
def run_master(args):
    torch.manual_seed(100)
@@ -13,22 +23,16 @@ def run_master(args):
    stage_num = args.world_size
    chunk = args.chunk
    num_microbatches = args.num_microbatches
-   actual_stage_num = stage_num * chunk
    use_checkpoint = args.use_checkpoint

    sample_num = 1024
-   feat_num = 10
-   h = 10
    batch_size = 1024

    assert sample_num % batch_size == 0
-   batch_num = sample_num // batch_size

    input_sample = torch.randn((sample_num, feat_num), device=device)

-   module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
-
-   engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+   engine = OneFOneBPipelineEngine(partition_fn=partition,
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,

View File

@@ -6,6 +6,15 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

+feat_num = 100
+h = 100
+
+
+def partition(pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
+    return partition
+
def run_master(args):
    torch.manual_seed(100)
@@ -18,25 +27,20 @@ def run_master(args):
    num_microbatches = args.num_microbatches

    sample_num = 1024
-   feat_num = 100
-   h = 100
    batch_size = 1024

    assert sample_num % batch_size == 0
-   batch_num = sample_num // batch_size

    input_sample = torch.randn((sample_num, feat_num), device=device)

-   module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
-
-   engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+   engine = OneFOneBPipelineEngine(partition_fn=partition,
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    checkpoint=use_checkpoint)

-   forward_result = engine.forward_backward(input_sample)[0]
+   forward_result = engine.forward_backward(input_sample)

    cuda_rpc_result = []
    single_result = []

@@ -50,7 +54,8 @@ def run_master(args):
        cuda_rpc_result.append(p)

    # compute forward result and backward grad of parameters just in rank_0
-   test_model = nn.Sequential(*module_partitions).to(device)
+   test_model = nn.Sequential(
+       *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
    input_sample = input_sample.requires_grad_()
    out_val = test_model(input_sample).sum()
    autograd.backward(out_val)

View File

@@ -4,7 +4,7 @@ import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import pytest

-from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
+from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized

@@ -26,7 +26,7 @@ def run_worker(rank, args):
    disable_existing_loggers()
    launch(dict(), rank, world_size, host, int(port), backend, verbose=False)

-   pg = PipelineProcessGroup(rank=rank,
+   ppg.set_global_info(rank=rank,
                        world_size=world_size,
                        dp_degree=dp_degree,
                        tp_degree=tp_degree,