mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-23 02:06:35 +00:00
[pipeline/chimera] test chimera | fix bug of initializing (#1615)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing
This commit is contained in:
parent
504ff1d101
commit
170fa81095
@ -110,6 +110,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
|||||||
name_list.append((name, param))
|
name_list.append((name, param))
|
||||||
|
|
||||||
for name, param in name_list:
|
for name, param in name_list:
|
||||||
|
if hasattr(module, name):
|
||||||
delattr(module, name)
|
delattr(module, name)
|
||||||
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
|
setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
from torch.distributed import rpc
|
from torch.distributed import rpc
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -10,13 +11,17 @@ from colossalai.tensor import ProcessGroup
|
|||||||
class PipelineProcessGroup:
|
class PipelineProcessGroup:
|
||||||
# TODO : flexible API for DP size and TP size
|
# TODO : flexible API for DP size and TP size
|
||||||
# In the future design mode, dp_degree and tp_degree should be removed
|
# In the future design mode, dp_degree and tp_degree should be removed
|
||||||
def __init__(self,
|
def __init__(self) -> None:
|
||||||
|
self.is_initialize = False
|
||||||
|
|
||||||
|
def set_global_info(self,
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
dp_degree: int = 1,
|
dp_degree: int = 1,
|
||||||
tp_degree: int = 1,
|
tp_degree: int = 1,
|
||||||
num_worker_threads: int = 1,
|
num_worker_threads: int = 1,
|
||||||
device: str = "cuda") -> None:
|
device: str = "cuda") -> None:
|
||||||
|
|
||||||
device_mesh_size = dp_degree * tp_degree
|
device_mesh_size = dp_degree * tp_degree
|
||||||
assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
|
assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
|
||||||
self._num_worker_threads = num_worker_threads
|
self._num_worker_threads = num_worker_threads
|
||||||
@ -42,6 +47,11 @@ class PipelineProcessGroup:
|
|||||||
self._is_first_pp_rank = self._pp_rank == 0
|
self._is_first_pp_rank = self._pp_rank == 0
|
||||||
self._is_last_pp_rank = self._pp_rank == self._stage_num - 1
|
self._is_last_pp_rank = self._pp_rank == self._stage_num - 1
|
||||||
|
|
||||||
|
self.is_initialize = True
|
||||||
|
|
||||||
|
# lock
|
||||||
|
self.chimera_lock = threading.Lock()
|
||||||
|
|
||||||
def _initialize_process_group(self):
|
def _initialize_process_group(self):
|
||||||
stage_num = self.get_stage_num()
|
stage_num = self.get_stage_num()
|
||||||
if stage_num == 1:
|
if stage_num == 1:
|
||||||
@ -133,3 +143,25 @@ class PipelineProcessGroup:
|
|||||||
|
|
||||||
def get_tp_global_ranks(self):
|
def get_tp_global_ranks(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_chimera_all_reduce_group(self, pp_rank: int):
|
||||||
|
with self.chimera_lock:
|
||||||
|
if not hasattr(self, 'chimera_groups'):
|
||||||
|
world_size = self.get_world_size()
|
||||||
|
stage_num = self.get_stage_num()
|
||||||
|
assert world_size % 2 == 0, 'world_size must be even in chimera!'
|
||||||
|
self.chimera_groups = {}
|
||||||
|
for rank in range(world_size // 2):
|
||||||
|
pair = [rank, world_size - 1 - rank]
|
||||||
|
group = dist.new_group(pair)
|
||||||
|
self.chimera_groups[pair[0]] = group
|
||||||
|
self.chimera_groups[pair[1]] = group
|
||||||
|
self.chimera_groups[pair[0] + stage_num] = group
|
||||||
|
self.chimera_groups[pair[1] + stage_num] = group
|
||||||
|
self.chimera_step_lock = threading.Lock()
|
||||||
|
self.chimera_step_lock.acquire()
|
||||||
|
|
||||||
|
return self.chimera_groups[pp_rank]
|
||||||
|
|
||||||
|
|
||||||
|
ppg = PipelineProcessGroup()
|
||||||
|
3
colossalai/pipeline/rpc/__init__.py
Normal file
3
colossalai/pipeline/rpc/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||||
|
|
||||||
|
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']
|
@ -139,7 +139,8 @@ class BackwardCache:
|
|||||||
class WorkerBase(ABC):
|
class WorkerBase(ABC):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
module_partition: nn.Module,
|
partition_fn: Callable,
|
||||||
|
partition_args: tuple,
|
||||||
pp_rank: int,
|
pp_rank: int,
|
||||||
actual_stage_num: int,
|
actual_stage_num: int,
|
||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
@ -165,21 +166,22 @@ class WorkerBase(ABC):
|
|||||||
# rref of other workers
|
# rref of other workers
|
||||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
|
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
|
||||||
|
|
||||||
|
# lock for the list
|
||||||
|
self._initialize_lock()
|
||||||
|
|
||||||
# topology info
|
# topology info
|
||||||
self.producer_stage_ids: List[int] = None
|
self.producer_stage_ids: List[int] = None
|
||||||
self.consumer_stage_ids: List[int] = None
|
self.consumer_stage_ids: List[int] = None
|
||||||
|
|
||||||
# module partitions
|
# module partitions
|
||||||
self.module_partition = module_partition.to(device)
|
self.partition_fn = partition_fn
|
||||||
|
self.partition_args = partition_args
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
|
|
||||||
# context to maintain loop
|
# context to maintain loop
|
||||||
self._initialize_context_container()
|
self._initialize_context_container()
|
||||||
|
|
||||||
# lock for the list
|
|
||||||
self._initialize_lock()
|
|
||||||
|
|
||||||
# main loop
|
# main loop
|
||||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||||
self.main_loop_thread.start()
|
self.main_loop_thread.start()
|
||||||
@ -202,20 +204,37 @@ class WorkerBase(ABC):
|
|||||||
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
||||||
|
|
||||||
def _initialize_lock(self):
|
def _initialize_lock(self):
|
||||||
|
self.partition_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||||
self.label_lock = threading.Condition(threading.Lock())
|
self.label_lock = threading.Condition(threading.Lock())
|
||||||
|
|
||||||
|
def _initialize_partition(self):
|
||||||
|
partition_fn = self.partition_fn
|
||||||
|
partition_args = self.partition_args
|
||||||
|
device = self.device
|
||||||
|
with self.partition_condition_lock:
|
||||||
|
self.module_partition: nn.Module = partition_fn(*partition_args).to(device)
|
||||||
|
self.partition_condition_lock.notify_all()
|
||||||
|
|
||||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||||
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||||
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
|
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
|
||||||
|
|
||||||
|
# for some schedule need the other worker's info to initialise partition (like Chimera)
|
||||||
|
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||||
|
self._initialize_partition()
|
||||||
|
|
||||||
def get_output_by_key(self, key: UniqueKey) -> Any:
|
def get_output_by_key(self, key: UniqueKey) -> Any:
|
||||||
with self.output_list_condition_lock:
|
with self.output_list_condition_lock:
|
||||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||||
output_work_item = self.output_list[key]
|
output_work_item = self.output_list[key]
|
||||||
output = output_work_item.output.wait()
|
|
||||||
|
output = output_work_item.output
|
||||||
|
if isinstance(output, Future):
|
||||||
|
output = output.wait()
|
||||||
|
|
||||||
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
||||||
output_work_item.refcount += 1
|
output_work_item.refcount += 1
|
||||||
|
|
||||||
@ -231,6 +250,16 @@ class WorkerBase(ABC):
|
|||||||
def get_parameter_gradients(self) -> List[torch.Tensor]:
|
def get_parameter_gradients(self) -> List[torch.Tensor]:
|
||||||
return [p.grad for p in self.module_partition.parameters()]
|
return [p.grad for p in self.module_partition.parameters()]
|
||||||
|
|
||||||
|
def get_partition(self):
|
||||||
|
with self.partition_condition_lock:
|
||||||
|
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||||
|
return self.module_partition
|
||||||
|
|
||||||
|
def get_partition_state_dict(self):
|
||||||
|
with self.partition_condition_lock:
|
||||||
|
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||||
|
return self.module_partition.state_dict()
|
||||||
|
|
||||||
# just for first pp_rank
|
# just for first pp_rank
|
||||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||||
assert self.consumer_stage_ids is not None
|
assert self.consumer_stage_ids is not None
|
||||||
@ -520,6 +549,15 @@ class WorkerBase(ABC):
|
|||||||
is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
|
is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
|
||||||
return is_last_phase and is_last_microbatch
|
return is_last_phase and is_last_microbatch
|
||||||
|
|
||||||
|
def _hook_before_step(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _reset_context(self):
|
||||||
|
self.forward_times = 0
|
||||||
|
self.backward_times = 0
|
||||||
|
self.outstanding = 0
|
||||||
|
self._initialize_outstanding_range()
|
||||||
|
|
||||||
# do the main loop to consume ready_list
|
# do the main loop to consume ready_list
|
||||||
def _work_loop(self):
|
def _work_loop(self):
|
||||||
# for init
|
# for init
|
||||||
@ -545,19 +583,17 @@ class WorkerBase(ABC):
|
|||||||
consume_result = self._consume_work_item_by_phase(work_item)
|
consume_result = self._consume_work_item_by_phase(work_item)
|
||||||
|
|
||||||
color_debug(
|
color_debug(
|
||||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
|
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
|
||||||
'work loop', 'green')
|
'work loop', 'green')
|
||||||
|
|
||||||
work_item.output.set_result(consume_result)
|
work_item.output.set_result(consume_result)
|
||||||
|
|
||||||
# if is last step in one batch reset context and do step
|
# if is last step in one batch reset context and do step
|
||||||
if self._is_last_step(work_item):
|
if self._is_last_step(work_item):
|
||||||
|
self._hook_before_step()
|
||||||
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
||||||
self.step()
|
self.step()
|
||||||
self.forward_times = 0
|
self._reset_context()
|
||||||
self.backward_times = 0
|
|
||||||
self.outstanding = 0
|
|
||||||
self._initialize_outstanding_range()
|
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||||
@ -577,7 +613,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
worker_type,
|
worker_type,
|
||||||
module_partitions,
|
partition_fn: Callable,
|
||||||
stage_num,
|
stage_num,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
device: str,
|
device: str,
|
||||||
@ -588,7 +624,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
checkpoint: bool = False) -> None:
|
checkpoint: bool = False) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.worker_type = worker_type
|
self.worker_type = worker_type
|
||||||
self.module_partitions: List[nn.Module] = module_partitions
|
self.partition_fn: Callable = partition_fn
|
||||||
self.chunk = chunk
|
self.chunk = chunk
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
@ -609,18 +645,15 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
def _check_argument(self) -> None:
|
def _check_argument(self) -> None:
|
||||||
self.virtual_stage_num = self.stage_num * self.chunk
|
self.virtual_stage_num = self.stage_num * self.chunk
|
||||||
|
|
||||||
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
|
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
|
||||||
assert self.virtual_stage_num == len(
|
|
||||||
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
|
|
||||||
|
|
||||||
def _get_actual_stage_num(self) -> int:
|
def _get_actual_stage_num(self) -> int:
|
||||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||||
|
|
||||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||||
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
||||||
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
|
e.g. If a model is splited into 4 parts, which means stage_num is 2, chunk is 2, then
|
||||||
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
||||||
of partitions will be moved to device 0 and the others to device 1
|
of partitions will be moved to device 0 and the others to device 1
|
||||||
"""
|
"""
|
||||||
stage_num = self.stage_num
|
stage_num = self.stage_num
|
||||||
@ -647,26 +680,34 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
device = self.device
|
device = self.device
|
||||||
criterion = self.criterion
|
criterion = self.criterion
|
||||||
metric = self.metric
|
metric = self.metric
|
||||||
|
partition_fn = self.partition_fn
|
||||||
|
chunk = self.chunk
|
||||||
|
|
||||||
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
|
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
|
||||||
module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
||||||
|
partition_args = (partition_id, chunk, actual_stage_num)
|
||||||
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
||||||
if device[:4] == 'cuda':
|
if device[:4] == 'cuda':
|
||||||
device = f'cuda:{rpc_worker_id}'
|
device = f'cuda:{rpc_worker_id}'
|
||||||
module_partition = self.module_partitions[module_partition_id]
|
|
||||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||||
worker_type,
|
worker_type,
|
||||||
args=(module_partition, pp_rank, actual_stage_num,
|
args=(partition_fn, partition_args, pp_rank,
|
||||||
num_microbatches, device, criterion, metric,
|
actual_stage_num, num_microbatches, device,
|
||||||
checkpoint))
|
criterion, metric, checkpoint))
|
||||||
|
|
||||||
# let each worker know global worker rref (include itself)
|
# let each worker know global worker rref (include itself)
|
||||||
|
sync_futs = []
|
||||||
for pp_rank in self.pp_rank_to_worker_rref:
|
for pp_rank in self.pp_rank_to_worker_rref:
|
||||||
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||||
|
sync_futs.append(fut)
|
||||||
|
|
||||||
|
for fut in sync_futs:
|
||||||
|
fut.wait()
|
||||||
|
|
||||||
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
|
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
|
||||||
parameters = {}
|
parameters = {}
|
||||||
for stage_id in self.pp_rank_to_worker_rref:
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
|
for stage_id in range(actual_stage_num):
|
||||||
parameters[stage_id] = []
|
parameters[stage_id] = []
|
||||||
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||||
for p in worker_rref.rpc_sync().get_parameters():
|
for p in worker_rref.rpc_sync().get_parameters():
|
||||||
@ -675,7 +716,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
|
|
||||||
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
|
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
|
||||||
grads = {}
|
grads = {}
|
||||||
for stage_id in self.pp_rank_to_worker_rref:
|
actual_stage_num = self._get_actual_stage_num()
|
||||||
|
for stage_id in range(actual_stage_num):
|
||||||
grads[stage_id] = []
|
grads[stage_id] = []
|
||||||
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||||
for grad in worker_rref.rpc_sync().get_parameter_gradients():
|
for grad in worker_rref.rpc_sync().get_parameter_gradients():
|
||||||
@ -784,7 +826,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
# collect forward result
|
# collect forward result
|
||||||
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
||||||
|
|
||||||
if not forward_only and labels is not None:
|
if not forward_only and hasattr(self, 'optimizer_class'):
|
||||||
# wait for all step
|
# wait for all step
|
||||||
for pp_rank in self.pp_rank_to_worker_rref:
|
for pp_rank in self.pp_rank_to_worker_rref:
|
||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
@ -793,9 +835,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||||||
return forward_result
|
return forward_result
|
||||||
|
|
||||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||||
actual_stage_num = self._get_actual_stage_num()
|
|
||||||
self.optimizer_class = optimizer_class
|
self.optimizer_class = optimizer_class
|
||||||
for pp_rank in range(actual_stage_num):
|
for pp_rank in self.pp_rank_to_worker_rref:
|
||||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||||
worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
|
worker_rref.remote().initialize_optimizer(optimizer_class, **kwargs)
|
||||||
|
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
from typing import List, Callable, Dict
|
from typing import List, Callable, Dict
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from torch.futures import Future
|
from torch.futures import Future
|
||||||
from torch._C._distributed_rpc import PyRRef
|
from torch._C._distributed_rpc import PyRRef
|
||||||
|
|
||||||
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
|
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
|
||||||
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
|
|
||||||
# Implementation of different Pipeline schedule
|
# Implementation of different Pipeline schedule
|
||||||
# <strategy>Worker defines the worker for each stage
|
# <strategy>Worker defines the worker for each stage
|
||||||
@ -35,7 +37,7 @@ class FillDrainWorker(WorkerBase):
|
|||||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
module_partitions: List[nn.Module],
|
partition_fn: Callable,
|
||||||
stage_num: int,
|
stage_num: int,
|
||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
device: str,
|
device: str,
|
||||||
@ -49,8 +51,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
|
|||||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||||
use_1F1B = False
|
use_1F1B = False
|
||||||
|
|
||||||
super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||||
criterion, metric, checkpoint)
|
metric, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
class OneFOneBWorker(WorkerBase):
|
class OneFOneBWorker(WorkerBase):
|
||||||
@ -94,7 +96,7 @@ class OneFOneBWorker(WorkerBase):
|
|||||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
module_partitions: List[nn.Module],
|
partition_fn: Callable,
|
||||||
stage_num: int,
|
stage_num: int,
|
||||||
num_microbatches: int,
|
num_microbatches: int,
|
||||||
device: str,
|
device: str,
|
||||||
@ -106,10 +108,11 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
|
|||||||
if chunk > 1:
|
if chunk > 1:
|
||||||
assert num_microbatches % stage_num == 0, \
|
assert num_microbatches % stage_num == 0, \
|
||||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||||
|
assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
|
||||||
use_1F1B = True
|
use_1F1B = True
|
||||||
|
|
||||||
super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||||
criterion, metric, checkpoint)
|
metric, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
class ChimeraWorker(WorkerBase):
|
class ChimeraWorker(WorkerBase):
|
||||||
@ -139,21 +142,16 @@ class ChimeraWorker(WorkerBase):
|
|||||||
stage_num = self.actual_stage_num
|
stage_num = self.actual_stage_num
|
||||||
real_microbatch_num = self.num_microbatches // 2
|
real_microbatch_num = self.num_microbatches // 2
|
||||||
|
|
||||||
if self.forward_times < real_microbatch_num:
|
forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
|
||||||
if (pp_rank + 1) % stage_num == 0: # last rank
|
forward_block_num = self.forward_times // forward_block_size
|
||||||
forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
|
|
||||||
if forward_blocks > self.backward_times:
|
if self.forward_times >= real_microbatch_num or \
|
||||||
|
((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
|
||||||
target_phase = Phase.BACKWARD
|
target_phase = Phase.BACKWARD
|
||||||
target_microbatch_id = self.backward_times
|
target_microbatch_id = self.backward_times
|
||||||
else:
|
|
||||||
target_phase = Phase.FORWARD
|
|
||||||
target_microbatch_id = self.forward_times
|
|
||||||
else: # others
|
else: # others
|
||||||
target_phase = Phase.FORWARD
|
target_phase = Phase.FORWARD
|
||||||
target_microbatch_id = self.forward_times
|
target_microbatch_id = self.forward_times
|
||||||
else:
|
|
||||||
target_phase = Phase.BACKWARD
|
|
||||||
target_microbatch_id = self.backward_times
|
|
||||||
|
|
||||||
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
|
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
|
||||||
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
|
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
|
||||||
@ -164,22 +162,85 @@ class ChimeraWorker(WorkerBase):
|
|||||||
|
|
||||||
with self.work_list_condition_lock:
|
with self.work_list_condition_lock:
|
||||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||||
|
|
||||||
return target_key
|
return target_key
|
||||||
|
|
||||||
|
def _initialize_partition(self):
|
||||||
|
# In order to ensure the down pipeline share the same parameter
|
||||||
|
# with the up pipeline, partition of down partition will be copied
|
||||||
|
# from corresponding up stage
|
||||||
|
pp_rank = self.pp_rank
|
||||||
|
stage_num = self.actual_stage_num
|
||||||
|
device = self.device
|
||||||
|
if pp_rank < stage_num:
|
||||||
|
super()._initialize_partition()
|
||||||
|
else:
|
||||||
|
# if it is down pipeline, create partition by origin method
|
||||||
|
co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
|
||||||
|
# get the coresponding model state dict and wait for its init
|
||||||
|
state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
|
||||||
|
super()._initialize_partition()
|
||||||
|
self.module_partition.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# init group for chimera in ppg
|
||||||
|
ppg.get_chimera_all_reduce_group(pp_rank)
|
||||||
|
|
||||||
def is_first_stage(self):
|
def is_first_stage(self):
|
||||||
return (self.pp_rank % self.actual_stage_num) == 0
|
return (self.pp_rank % self.actual_stage_num) == 0
|
||||||
|
|
||||||
def is_last_stage(self):
|
def is_last_stage(self):
|
||||||
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
|
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
|
||||||
|
|
||||||
|
def _is_last_step(self, work_item: WorkItem) -> bool:
|
||||||
|
if work_item.forward_only:
|
||||||
|
last_phase = Phase.FORWARD
|
||||||
|
else:
|
||||||
|
last_phase = Phase.BACKWARD
|
||||||
|
is_last_phase = work_item.phase == last_phase
|
||||||
|
last_microbatch_id = self.num_microbatches - 1
|
||||||
|
if self.pp_rank < self.actual_stage_num:
|
||||||
|
last_microbatch_id -= 1
|
||||||
|
is_last_microbatch = work_item.microbatch_id == last_microbatch_id
|
||||||
|
return is_last_phase and is_last_microbatch
|
||||||
|
|
||||||
|
def _get_step_order(self) -> List[int]:
|
||||||
|
# TODO : If you want to extend it to multi head chimera, overwrite here
|
||||||
|
stage_num = self.actual_stage_num
|
||||||
|
pp_rank = self.pp_rank
|
||||||
|
# pp_rank in the same device
|
||||||
|
local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
|
||||||
|
local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
|
||||||
|
return local_device_pp_ranks
|
||||||
|
|
||||||
|
def _hook_before_step(self):
|
||||||
|
pp_rank = self.pp_rank
|
||||||
|
|
||||||
|
orders = self._get_step_order()
|
||||||
|
step_index = orders.index(pp_rank)
|
||||||
|
|
||||||
|
# if currrent pp_rank is not the first to do step
|
||||||
|
# wait its previous pp_rank finish step
|
||||||
|
|
||||||
|
all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
|
||||||
|
grads = self.get_parameter_gradients()
|
||||||
|
|
||||||
|
# print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
|
||||||
|
if step_index == 1:
|
||||||
|
ppg.chimera_step_lock.acquire()
|
||||||
|
|
||||||
|
print(f'rank_{self.pp_rank} before all reduce')
|
||||||
|
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
|
||||||
|
print(f'rank_{self.pp_rank} after all reduce')
|
||||||
|
|
||||||
|
if step_index == 0:
|
||||||
|
ppg.chimera_step_lock.release()
|
||||||
|
|
||||||
|
|
||||||
class ChimeraPipelineEngine(PipelineEngineBase):
|
class ChimeraPipelineEngine(PipelineEngineBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
module_partitions,
|
partition_fn: Callable,
|
||||||
stage_num,
|
stage_num: int,
|
||||||
num_microbatches,
|
num_microbatches: int,
|
||||||
device: str,
|
device: str,
|
||||||
criterion: Callable = None,
|
criterion: Callable = None,
|
||||||
metric: Callable = None,
|
metric: Callable = None,
|
||||||
@ -189,11 +250,12 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||||||
"In Chimera, num_microbatches must be the multiply of stage_num!"
|
"In Chimera, num_microbatches must be the multiply of stage_num!"
|
||||||
use_1F1B = False
|
use_1F1B = False
|
||||||
chunk = 1
|
chunk = 1
|
||||||
super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
|
||||||
criterion, metric, checkpoint)
|
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||||
|
metric, checkpoint)
|
||||||
|
|
||||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
|
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
|
||||||
input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
|
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||||
@ -254,7 +316,6 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||||||
|
|
||||||
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
|
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
|
||||||
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
|
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
|
||||||
|
|
||||||
up_worker_rref.rpc_sync().get_output_by_key(up_key)
|
up_worker_rref.rpc_sync().get_output_by_key(up_key)
|
||||||
down_worker_rref.rpc_sync().get_output_by_key(down_key)
|
down_worker_rref.rpc_sync().get_output_by_key(down_key)
|
||||||
|
|
||||||
|
Binary file not shown.
@ -8,8 +8,13 @@ import torch.multiprocessing as mp
|
|||||||
import torch.distributed.rpc as rpc
|
import torch.distributed.rpc as rpc
|
||||||
from torch.optim import SGD, Adam, RMSprop, Optimizer
|
from torch.optim import SGD, Adam, RMSprop, Optimizer
|
||||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||||
|
import torch.distributed as dist
|
||||||
from colorama import Back, Style
|
from colorama import Back, Style
|
||||||
|
|
||||||
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai import launch
|
||||||
|
|
||||||
rpc_is_initialized = _is_current_rpc_agent_set
|
rpc_is_initialized = _is_current_rpc_agent_set
|
||||||
|
|
||||||
|
|
||||||
@ -25,12 +30,15 @@ class RpcTestModel(nn.Module):
|
|||||||
self.rank = stage_id
|
self.rank = stage_id
|
||||||
self.is_last_rank = stage_id == actual_stage_num - 1
|
self.is_last_rank = stage_id == actual_stage_num - 1
|
||||||
self.linear_name = f'linear_{stage_id}'
|
self.linear_name = f'linear_{stage_id}'
|
||||||
|
|
||||||
if stage_id == 0:
|
if stage_id == 0:
|
||||||
setattr(self, self.linear_name, nn.Linear(feat_num, h))
|
linear = nn.Linear(feat_num, h)
|
||||||
elif stage_id == actual_stage_num - 1:
|
elif stage_id == actual_stage_num - 1:
|
||||||
setattr(self, self.linear_name, nn.Linear(h, 1))
|
linear = nn.Linear(h, 1)
|
||||||
else:
|
else:
|
||||||
setattr(self, self.linear_name, nn.Linear(h, h))
|
linear = nn.Linear(h, h)
|
||||||
|
|
||||||
|
setattr(self, self.linear_name, linear)
|
||||||
|
|
||||||
def forward(self, x) -> torch.Tensor:
|
def forward(self, x) -> torch.Tensor:
|
||||||
linear: nn.Module = getattr(self, self.linear_name)
|
linear: nn.Module = getattr(self, self.linear_name)
|
||||||
@ -46,6 +54,8 @@ def parse_args():
|
|||||||
parser.add_argument('--epoch', type=int, default=1)
|
parser.add_argument('--epoch', type=int, default=1)
|
||||||
parser.add_argument('--world_size', type=int, default=2)
|
parser.add_argument('--world_size', type=int, default=2)
|
||||||
parser.add_argument('--batch_size', type=int, default=16)
|
parser.add_argument('--batch_size', type=int, default=16)
|
||||||
|
parser.add_argument('--dp_degree', type=int, default=1)
|
||||||
|
parser.add_argument('--tp_degree', type=int, default=1)
|
||||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||||
parser.add_argument('--chunk', type=int, default=1)
|
parser.add_argument('--chunk', type=int, default=1)
|
||||||
parser.add_argument('--use_checkpoint', action='store_true')
|
parser.add_argument('--use_checkpoint', action='store_true')
|
||||||
@ -74,16 +84,24 @@ def run_worker(rank, args, master_func):
|
|||||||
os.environ['MASTER_ADDR'] = args.master_addr
|
os.environ['MASTER_ADDR'] = args.master_addr
|
||||||
os.environ['MASTER_PORT'] = args.master_port
|
os.environ['MASTER_PORT'] = args.master_port
|
||||||
|
|
||||||
# config rpc
|
device = args.device
|
||||||
# if cuda is used, set_device_map is a must is configured
|
|
||||||
# for cuda is not supported in torch rpc by default
|
|
||||||
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
|
|
||||||
|
|
||||||
world_size = args.world_size
|
world_size = args.world_size
|
||||||
for rank_idx in range(world_size):
|
dp_degree = args.dp_degree
|
||||||
options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
|
tp_degree = args.tp_degree
|
||||||
|
num_worker_threads = args.num_worker_threads
|
||||||
|
host = args.master_addr
|
||||||
|
port = args.master_port
|
||||||
|
backend = 'nccl' if device == 'cuda' else 'gloo'
|
||||||
|
|
||||||
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
|
disable_existing_loggers()
|
||||||
|
|
||||||
|
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||||
|
ppg.set_global_info(rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
dp_degree=dp_degree,
|
||||||
|
tp_degree=tp_degree,
|
||||||
|
num_worker_threads=num_worker_threads,
|
||||||
|
device=device)
|
||||||
|
|
||||||
# in rpc mode, only rank 0 is needed to be coded
|
# in rpc mode, only rank 0 is needed to be coded
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
@ -1,9 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import torch.autograd as autograd
|
||||||
|
|
||||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
|
from colossalai.pipeline.rpc import ChimeraPipelineEngine
|
||||||
|
from colossalai.testing import assert_close
|
||||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||||
|
|
||||||
|
# global variable for model created
|
||||||
|
feat_num = 100
|
||||||
|
h = 100
|
||||||
|
|
||||||
|
|
||||||
|
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||||
|
torch.manual_seed(1024)
|
||||||
|
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||||
|
return partition
|
||||||
|
|
||||||
|
|
||||||
def run_master(args):
|
def run_master(args):
|
||||||
torch.manual_seed(100)
|
torch.manual_seed(100)
|
||||||
@ -17,23 +29,51 @@ def run_master(args):
|
|||||||
use_checkpoint = False
|
use_checkpoint = False
|
||||||
|
|
||||||
sample_num = 1024
|
sample_num = 1024
|
||||||
feat_num = 10
|
|
||||||
h = 10
|
|
||||||
batch_size = 1024
|
batch_size = 1024
|
||||||
|
|
||||||
assert sample_num % batch_size == 0
|
assert sample_num % batch_size == 0
|
||||||
|
|
||||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
engine = ChimeraPipelineEngine(partition_fn=partition,
|
||||||
engine = ChimeraPipelineEngine(module_partitions=module_partitions,
|
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
checkpoint=use_checkpoint)
|
checkpoint=use_checkpoint)
|
||||||
|
engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
|
||||||
|
|
||||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||||
|
|
||||||
for _ in range(epoch):
|
forward_result = engine.forward_backward(input_sample)
|
||||||
_ = engine.forward_backward(input_sample, forward_only=False)
|
|
||||||
|
cuda_rpc_result = []
|
||||||
|
single_result = []
|
||||||
|
actual_stage_num = engine._get_actual_stage_num()
|
||||||
|
|
||||||
|
# compute forward result and backward grad of parameters in cuda rpc
|
||||||
|
cuda_rpc_result.append(sum(forward_result[0]))
|
||||||
|
grad = engine.remote_grad()
|
||||||
|
for stage_id in range(actual_stage_num):
|
||||||
|
for p in grad[stage_id]:
|
||||||
|
cuda_rpc_result.append(p)
|
||||||
|
|
||||||
|
# compute forward result and backward grad of parameters just in rank_0
|
||||||
|
test_model = nn.Sequential(
|
||||||
|
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||||
|
# input_sample = input_sample[len(input_sample) // 2:]
|
||||||
|
input_sample = input_sample.requires_grad_()
|
||||||
|
out_val = test_model(input_sample).sum()
|
||||||
|
autograd.backward(out_val)
|
||||||
|
single_result.append(out_val)
|
||||||
|
for p in test_model.parameters():
|
||||||
|
single_result.append(p.grad)
|
||||||
|
|
||||||
|
# print("my")
|
||||||
|
# print(cuda_rpc_result[1])
|
||||||
|
# print("answer:")
|
||||||
|
# print(single_result[1])
|
||||||
|
|
||||||
|
# assert len(cuda_rpc_result) == len(single_result)
|
||||||
|
# for r_c, r_s in zip(cuda_rpc_result, single_result):
|
||||||
|
# assert_close(r_c, r_s, 0.001, 0.001)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -7,6 +7,16 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
|
|||||||
from colossalai.testing import assert_close
|
from colossalai.testing import assert_close
|
||||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||||
|
|
||||||
|
# global variable for model created
|
||||||
|
feat_num = 100
|
||||||
|
h = 100
|
||||||
|
|
||||||
|
|
||||||
|
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||||
|
torch.manual_seed(1024)
|
||||||
|
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||||
|
return partition
|
||||||
|
|
||||||
|
|
||||||
def run_master(args):
|
def run_master(args):
|
||||||
torch.manual_seed(100)
|
torch.manual_seed(100)
|
||||||
@ -20,20 +30,14 @@ def run_master(args):
|
|||||||
optimizer_class = globals()[args.optimizer]
|
optimizer_class = globals()[args.optimizer]
|
||||||
|
|
||||||
lr = 1e-3
|
lr = 1e-3
|
||||||
|
|
||||||
sample_num = 1024
|
sample_num = 1024
|
||||||
feat_num = 100
|
|
||||||
h = 100
|
|
||||||
batch_size = 1024
|
batch_size = 1024
|
||||||
|
|
||||||
assert sample_num % batch_size == 0
|
assert sample_num % batch_size == 0
|
||||||
batch_num = sample_num // batch_size
|
|
||||||
|
|
||||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||||
|
|
||||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||||
|
|
||||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
@ -55,7 +59,8 @@ def run_master(args):
|
|||||||
cuda_rpc_result.append(p)
|
cuda_rpc_result.append(p)
|
||||||
|
|
||||||
# compute forward result and backward grad of parameters just in rank_0
|
# compute forward result and backward grad of parameters just in rank_0
|
||||||
test_model = nn.Sequential(*module_partitions).to(device)
|
test_model = nn.Sequential(
|
||||||
|
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||||
optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
|
optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
|
||||||
input_sample = input_sample.requires_grad_()
|
input_sample = input_sample.requires_grad_()
|
||||||
out_val = test_model(input_sample).sum()
|
out_val = test_model(input_sample).sum()
|
||||||
|
@ -18,17 +18,30 @@ from colossalai.trainer import Trainer, hooks
|
|||||||
from colossalai.utils import MultiTimer, get_dataloader
|
from colossalai.utils import MultiTimer, get_dataloader
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
|
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
|
||||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||||
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
|
|
||||||
|
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
return torch.flatten(x, 1)
|
return torch.flatten(x, 1)
|
||||||
|
|
||||||
|
|
||||||
class Flatten(nn.Module):
|
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||||
|
pipelinable = PipelinableContext()
|
||||||
|
|
||||||
def forward(self, x):
|
# build model partitions
|
||||||
return torch.flatten(x, start_dim=1)
|
with pipelinable:
|
||||||
|
# input : [B, 3, 32, 32]
|
||||||
|
_ = resnet50()
|
||||||
|
|
||||||
|
pipelinable.policy = "customized"
|
||||||
|
|
||||||
|
exec_seq = [
|
||||||
|
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
|
||||||
|
]
|
||||||
|
pipelinable.to_layer_list(exec_seq)
|
||||||
|
partition = pipelinable.partition(chunk, stage_num, pp_rank)
|
||||||
|
return partition
|
||||||
|
|
||||||
|
|
||||||
def run_master(args):
|
def run_master(args):
|
||||||
@ -39,37 +52,12 @@ def run_master(args):
|
|||||||
stage_num = world_size
|
stage_num = world_size
|
||||||
num_microbatches = args.num_microbatches
|
num_microbatches = args.num_microbatches
|
||||||
|
|
||||||
assert chunk == 1
|
|
||||||
|
|
||||||
pipelinable = PipelinableContext()
|
|
||||||
|
|
||||||
# build model partitions
|
|
||||||
with pipelinable:
|
|
||||||
# input : [B, 3, 32, 32]
|
|
||||||
model = resnet50()
|
|
||||||
|
|
||||||
exec_seq = [
|
|
||||||
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
|
|
||||||
]
|
|
||||||
pipelinable.to_layer_list(exec_seq)
|
|
||||||
module_partitions: List[PipelinableModel] = [
|
|
||||||
pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
root = os.environ.get('DATA', './data')
|
root = os.environ.get('DATA', './data')
|
||||||
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
|
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
partition_1 = module_partitions[0]
|
pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||||
partition_2 = []
|
|
||||||
for model in module_partitions[1]._module_list:
|
|
||||||
partition_2.append(model)
|
|
||||||
partition_2.insert(len(partition_2) - 1, Flatten())
|
|
||||||
partition_2 = nn.Sequential(*partition_2)
|
|
||||||
module_partitions = [partition_1, partition_2]
|
|
||||||
|
|
||||||
pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -4,6 +4,16 @@ from torch import nn
|
|||||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||||
|
|
||||||
|
# global variable for model created
|
||||||
|
feat_num = 100
|
||||||
|
h = 100
|
||||||
|
|
||||||
|
|
||||||
|
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||||
|
torch.manual_seed(1024)
|
||||||
|
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||||
|
return partition
|
||||||
|
|
||||||
|
|
||||||
def run_master(args):
|
def run_master(args):
|
||||||
torch.manual_seed(100)
|
torch.manual_seed(100)
|
||||||
@ -13,22 +23,16 @@ def run_master(args):
|
|||||||
stage_num = args.world_size
|
stage_num = args.world_size
|
||||||
chunk = args.chunk
|
chunk = args.chunk
|
||||||
num_microbatches = args.num_microbatches
|
num_microbatches = args.num_microbatches
|
||||||
actual_stage_num = stage_num * chunk
|
|
||||||
use_checkpoint = args.use_checkpoint
|
use_checkpoint = args.use_checkpoint
|
||||||
|
|
||||||
sample_num = 1024
|
sample_num = 1024
|
||||||
feat_num = 10
|
|
||||||
h = 10
|
|
||||||
batch_size = 1024
|
batch_size = 1024
|
||||||
|
|
||||||
assert sample_num % batch_size == 0
|
assert sample_num % batch_size == 0
|
||||||
batch_num = sample_num // batch_size
|
|
||||||
|
|
||||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||||
|
|
||||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||||
|
|
||||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
|
@ -6,6 +6,15 @@ from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine,
|
|||||||
from colossalai.testing import assert_close
|
from colossalai.testing import assert_close
|
||||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||||
|
|
||||||
|
feat_num = 100
|
||||||
|
h = 100
|
||||||
|
|
||||||
|
|
||||||
|
def partition(pp_rank: int, chunk: int, stage_num: int):
|
||||||
|
torch.manual_seed(1024)
|
||||||
|
partition = RpcTestModel(pp_rank, stage_num, feat_num, h)
|
||||||
|
return partition
|
||||||
|
|
||||||
|
|
||||||
def run_master(args):
|
def run_master(args):
|
||||||
torch.manual_seed(100)
|
torch.manual_seed(100)
|
||||||
@ -18,25 +27,20 @@ def run_master(args):
|
|||||||
num_microbatches = args.num_microbatches
|
num_microbatches = args.num_microbatches
|
||||||
|
|
||||||
sample_num = 1024
|
sample_num = 1024
|
||||||
feat_num = 100
|
|
||||||
h = 100
|
|
||||||
batch_size = 1024
|
batch_size = 1024
|
||||||
|
|
||||||
assert sample_num % batch_size == 0
|
assert sample_num % batch_size == 0
|
||||||
batch_num = sample_num // batch_size
|
|
||||||
|
|
||||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||||
|
|
||||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
engine = OneFOneBPipelineEngine(partition_fn=partition,
|
||||||
|
|
||||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
|
||||||
stage_num=stage_num,
|
stage_num=stage_num,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
device=device,
|
device=device,
|
||||||
chunk=chunk,
|
chunk=chunk,
|
||||||
checkpoint=use_checkpoint)
|
checkpoint=use_checkpoint)
|
||||||
|
|
||||||
forward_result = engine.forward_backward(input_sample)[0]
|
forward_result = engine.forward_backward(input_sample)
|
||||||
|
|
||||||
cuda_rpc_result = []
|
cuda_rpc_result = []
|
||||||
single_result = []
|
single_result = []
|
||||||
@ -50,7 +54,8 @@ def run_master(args):
|
|||||||
cuda_rpc_result.append(p)
|
cuda_rpc_result.append(p)
|
||||||
|
|
||||||
# compute forward result and backward grad of parameters just in rank_0
|
# compute forward result and backward grad of parameters just in rank_0
|
||||||
test_model = nn.Sequential(*module_partitions).to(device)
|
test_model = nn.Sequential(
|
||||||
|
*[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
|
||||||
input_sample = input_sample.requires_grad_()
|
input_sample = input_sample.requires_grad_()
|
||||||
out_val = test_model(input_sample).sum()
|
out_val = test_model(input_sample).sum()
|
||||||
autograd.backward(out_val)
|
autograd.backward(out_val)
|
||||||
|
@ -4,7 +4,7 @@ import torch.distributed.rpc as rpc
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
|
from colossalai.pipeline.pipeline_process_group import ppg
|
||||||
from colossalai.initialize import launch
|
from colossalai.initialize import launch
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
from rpc_test_utils import pg_parse_args, rpc_is_initialized
|
from rpc_test_utils import pg_parse_args, rpc_is_initialized
|
||||||
@ -26,7 +26,7 @@ def run_worker(rank, args):
|
|||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||||
|
|
||||||
pg = PipelineProcessGroup(rank=rank,
|
ppg.set_global_info(rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
dp_degree=dp_degree,
|
dp_degree=dp_degree,
|
||||||
tp_degree=tp_degree,
|
tp_degree=tp_degree,
|
||||||
|
Loading…
Reference in New Issue
Block a user