Mirror of https://github.com/hpcaitech/ColossalAI.git
[rpc] split with dag (#2028)
* add DAG to split_module
* add comment
* add test case for DAG
* remove print
* add DAG middleware in scheduler
* add test case for scheduler
* remove break
* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
  parent 96134e7be3
  commit b0936e4a44
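At a glance, the commit threads a new merge_output flag from split_with_split_nodes_pass into split_module, and the RPC workers then pick up a DAG attached to their partition (module_partition._DAG) as "middleware" info for routing inputs and outputs between stages. The sketch below condenses the test file added by this commit (tests/test_pipeline/test_middleware_1f1b.py); it is a hedged illustration, not part of the diff, and the helper name trace_and_split is hypothetical.

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass


def trace_and_split(model: torch.nn.Module, data_kwargs: dict, stage_num: int):
    # trace with meta tensors so no real computation or memory is used
    tracer = ColoTracer()
    meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)

    # annotate pipeline split points, then split the graph;
    # merge_output=True is the new flag this commit forwards to split_module
    annotated = balanced_split_pass(gm, stage_num)
    split_model, split_submodules = split_with_split_nodes_pass(annotated, merge_output=True)

    # each pipeline stage later picks its own submodule, exactly as the new test does
    return list(split_model.children())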
@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     return gm


-def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
+def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
     # TODO(lyl): use partition IR to assign partition ID to each node.
     # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
     # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node

@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
             part_idx += 1
         return part_idx

-    split_mod = split_module(annotated_gm, None, split_callback)
+    split_mod = split_module(annotated_gm, None, split_callback, merge_output)
     split_submodules = []
     for name, submodule in split_mod.named_modules():
         if isinstance(submodule, torch.fx.GraphModule):

@@ -199,23 +199,16 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=False):
         for partition in partitions:
             if node == partition:
                 user_partition_names.append(partition.name)

     # find user with getitem call
     else:
         for partition in partitions:
             if node in partition.args:
                 user_partition_names.append(partition.name)

-    is_output = False
-
-    def find_output(def_node, output_node):
-        nonlocal is_output
-        if def_node == output_node:
-            is_output = True
-
     if output_partitions is not None:
         output_node = output_partitions[0]
-        torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
-
-        if is_output:
+        if node.op == output_node.op:
             user_partition_names.append('MODEL_OUTPUT')

     if len(user_partition_names) > 0:
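The scheduler hunks that follow read a DAG attached to each split module (module_partition._DAG). The diff never shows that DAG being constructed, so the layout below is only an illustration inferred from how the worker code indexes it (DAG['input_partition'], DAG[name]['input'], DAG[name]['output'], DAG[name]['output_len'], 'submod_<pp_rank>' partition names, and the special 'MODEL_INPUT'/'MODEL_OUTPUT' markers); treat every key and value here as an assumption.

# Illustrative only: an assumed two-stage DAG shaped the way WorkerBase reads it.
example_dag = {
    # one entry per model input; each maps consumer partition name -> output offsets
    'input_partition': {
        'x': {'output': {'submod_0': [0]}},
    },
    'submod_0': {
        'input': {'MODEL_INPUT': [0]},    # producer name -> offsets into this stage's args
        'output': {'submod_1': [0]},      # consumer name -> offsets into this stage's outputs
        'output_len': 1,
    },
    'submod_1': {
        'input': {'submod_0': [0]},
        'output': {'MODEL_OUTPUT': [0]},
        'output_len': 1,
    },
}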
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
 import torch
 import torch.distributed.rpc as rpc
 from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
+from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
                                            split_batch, tensor_shape_list, type_detail)
 from torch import autograd, nn, optim
 from torch._C._distributed_rpc import PyRRef

@@ -20,7 +20,7 @@ class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
     UPDATE = 2
+    INPUT = 3

 class UniqueKey:
     __slots__ = ('microbatch_id', 'phase')

@@ -128,6 +128,7 @@ class WorkerBase(ABC):
         # topology info
         self.producer_stage_ids: List[int] = None
         self.consumer_stage_ids: List[int] = None
+        self.input_consumer_stage_ids: List[int] = None

         # module partitions
         self.partition_fn = partition_fn

@@ -135,6 +136,11 @@ class WorkerBase(ABC):
         self.criterion = criterion
         self.metric = metric

+        # middleware info
+        self._is_input = False
+        self._is_output = False
+        self._producer_consumer_initialized = False
+
         # context to maintain loop
         self._initialize_context_container()

@@ -164,6 +170,7 @@ class WorkerBase(ABC):
         self.work_list_condition_lock = threading.Condition(threading.Lock())
         self.output_list_condition_lock = threading.Condition(threading.Lock())
         self.label_lock = threading.Condition(threading.Lock())
+        self.producer_consumer_init_lock = threading.Condition(threading.Lock())

     def _initialize_partition(self):
         partition_fn = self.partition_fn

@@ -182,7 +189,7 @@ class WorkerBase(ABC):
         # construction of partition is executed after the registion of pp_rank_to_worker_rref
         self._initialize_partition()

-    def get_output_by_key(self, key: UniqueKey) -> Any:
+    def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]

@@ -191,8 +198,9 @@ class WorkerBase(ABC):
         if isinstance(output, Future):
             output = output.wait()

-        output_work_item.refcount += 1
+        # output_work_item.refcount += 1

+        # TODO(jiangziyue) redesign lifecycle management for DAG scheduler
         # all consumers have been satisfied, the work_item can be released
         with self.output_list_condition_lock:
             if output_work_item.refcount >= len(self.consumer_stage_ids):

@@ -215,8 +223,10 @@ class WorkerBase(ABC):
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
         return self.module_partition.state_dict()

-    def _make_args_kwargs(self, microbatch):
+    def _make_args_kwargs(self, microbatch, merge=False):
         if isinstance(microbatch, dict):
+            if merge:
+                return list(microbatch.values()), {}
             return [], microbatch
         elif isinstance(microbatch, torch.Tensor):
             return [microbatch], {}
@@ -228,16 +238,24 @@ class WorkerBase(ABC):
                     kwargs.update(arg)
                 else:
                     args.append(arg)
+            if merge:
+                arg_lst = args
+                for arg in kwargs.values():
+                    arg_lst.append(arg)
+                return arg_lst, {}
             return args, kwargs
         else:
             raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")

     # just for first pp_rank
+    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
+    # TODO(jiangziyue) Define a Class for DAG.
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
         assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()

+        if not self.use_middleware():
             # make args and kwargs
             args, kwargs = self._make_args_kwargs(microbatch)

@@ -246,6 +264,44 @@ class WorkerBase(ABC):
             with self.work_list_condition_lock:
                 self.work_list[key] = work_item
                 self.work_list_condition_lock.notify_all()
+        else:
+            # make args and kwargs
+            arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
+
+            # first stage assign correct input into other stages
+            DAG = self.get_DAG()
+            DAG_node = DAG['input_partition']
+            self_input_offsets = []
+            recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
+            # notify rank which should receive extra input
+            offset = 0
+            for details in DAG_node.values():
+                for partition_name in details['output'].keys():
+                    recv_rank = self.partition_name_to_pp_rank(partition_name)
+                    if recv_rank == self.pp_rank:
+                        self_input_offsets.append(offset)
+                    elif recv_rank not in self.input_consumer_stage_ids:
+                        self.input_consumer_stage_ids.append(recv_rank)
+                offset += 1
+
+            # set input for self rank
+            self_arg_lst = []
+            for off in self_input_offsets:
+                self_arg_lst.append(arg_lst[off])
+
+            work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
+                                 self.num_microbatches, forward_only)
+            with self.work_list_condition_lock:
+                self.work_list[key] = work_item
+                self.work_list_condition_lock.notify_all()
+
+            # put input tensor which other nodes need into output_list
+            work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
+                                        self.num_microbatches, forward_only)
+
+            with self.output_list_condition_lock:
+                self.output_list[recv_input_key] = work_item_remote
+                self.output_list_condition_lock.notify_all()

     # just for last pp_rank
     def set_labels(self, microbatch_id: int, microlabels: Any):
@@ -268,23 +324,58 @@ class WorkerBase(ABC):
             self.work_list[key] = work_item
             self.work_list_condition_lock.notify_all()

+    # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
     def subscribe_producer(self, microbatch_id: int, forward_only: bool):
         """
         You should call this function asynchronously
         """
-        assert self.producer_stage_ids is not None
-        producer_num = len(self.producer_stage_ids)
-        assert producer_num > 0, "only stage that has producers can subscribe producers"
-
         stage_id = self.pp_rank
-        subscribe_forward_futures: List[Future] = [None] * producer_num
         output = self._get_future_by_device()
+        if not self.use_middleware():
+            producer_num = len(self.producer_stage_ids)
+            subscribe_forward_futures: List[Future] = [None] * producer_num
             for i in range(producer_num):
                 producer_stage_id = self.producer_stage_ids[i]
                 producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
                 producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
                 subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
+        else:
+            with self.work_list_condition_lock:
+                key = UniqueKey(microbatch_id, Phase.FORWARD)
+                if key in self.work_list:
+                    return
+
+            producer_stage_ids = []
+            with self.producer_consumer_init_lock:
+                self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
+                producer_stage_ids = self.producer_stage_ids
+            producer_num = len(producer_stage_ids)
+
+            # TODO(jiangziyue) get single value instead of the whole output
+            if self.need_model_input():
+                producer_num += 1    # extra one(the last one) for input_tensor
+            subscribe_forward_futures: List[Future] = [None] * producer_num
+
+            # TODO(jiangziyue) get single value instead of the whole output
+            if self.need_model_input():
+                producer_stage_id = 0
+                producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
+                producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+
+                for i in range(0, producer_num-1):
+                    producer_stage_id = producer_stage_ids[i]
+                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
+                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                    subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
+
+            else:
+                for i in range(producer_num):
+                    producer_stage_id = producer_stage_ids[i]
+                    producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
+                    producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
+                    #producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
+                    subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)

@@ -292,7 +383,7 @@ class WorkerBase(ABC):
         # add work_item to work_list
         with self.work_list_condition_lock:
             key = UniqueKey(microbatch_id, Phase.FORWARD)
-            assert key not in self.work_list
+            if key not in self.work_list:
                 self.work_list[key] = work_item_from_producer
             self.work_list_condition_lock.notify_all()

@@ -334,6 +425,7 @@ class WorkerBase(ABC):
         self.producer_stage_ids = []
         self.consumer_stage_ids = []

+        if not self.use_middleware():
             # Just for demo
             prev_rank = rank - 1
             next_rank = rank + 1

@@ -341,6 +433,124 @@ class WorkerBase(ABC):
                 self.producer_stage_ids.append(prev_rank)
             if next_rank <= self.actual_stage_num - 1:
                 self.consumer_stage_ids.append(next_rank)
+        else:
+            self.input_consumer_stage_ids = []
+            DAG = self.get_DAG()
+            DAG_node_name = self.pp_rank_to_partition_name(rank)
+            DAG_node = DAG[DAG_node_name]
+            for partition_name in DAG_node['input'].keys():
+                if partition_name == 'MODEL_INPUT':
+                    self._is_input = True
+                else:
+                    prev_rank = self.partition_name_to_pp_rank(partition_name)
+                    self.producer_stage_ids.append(prev_rank)
+
+            for partition_name in DAG_node['output'].keys():
+                if partition_name == 'MODEL_OUTPUT':
+                    self._is_output = True
+                else:
+                    next_rank = self.partition_name_to_pp_rank(partition_name)
+                    self.consumer_stage_ids.append(next_rank)
+
+        # TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
+        with self.producer_consumer_init_lock:
+            self._producer_consumer_initialized = True
+            self.producer_consumer_init_lock.notify_all()
+
+    # TODO(jiangziyue) Define a Class for DAG.
+    def pp_rank_to_partition_name(self, pp_rank: int):
+        prefix = 'submod_'
+        partition_name = prefix + str(pp_rank)
+        return partition_name
+
+    # TODO(jiangziyue) Define a Class for DAG.
+    def partition_name_to_pp_rank(self, partition_name: str) -> int:
+        prefix = 'submod_'
+        pp_rank = int(partition_name.split(prefix)[-1])
+        return pp_rank
+
+    def get_DAG(self):
+        with self.partition_condition_lock:
+            self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
+            if hasattr(self.module_partition, '_DAG'):
+                return self.module_partition._DAG
+            else:
+                return None
+
+    def use_middleware(self):
+        DAG = self.get_DAG()
+        return DAG is not None
+
+    # TODO(jiangziyue) get single value instead of the whole output
+    def _get_real_args_kwargs(self, args_or_kwargs):
+        if not self.use_middleware():
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    args_or_kwargs = flatten_args
+        else:
+            args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+            if args_or_kwargs is not None:
+                if isinstance(args_or_kwargs, dict):
+                    pass
+                else:
+                    flatten_args = []
+                    if self.is_first_stage():
+                        pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+                    # TODO get by offset
+                    else:
+                        DAG = self.get_DAG()
+                        producer_outputs = {}
+                        cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
+                        #cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
+                        for i, args_from_one_mod in enumerate(args_or_kwargs):
+                            producer_output_offsets = []
+                            if self.need_model_input():
+                                if i == 0:
+                                    producer_DAG_node = DAG['input_partition']
+                                    producer_partition_name = 'MODEL_INPUT'
+                                    offset = 0
+                                    for arg_info in producer_DAG_node.values():
+                                        if cur_DAG_node_name in arg_info['output']:
+                                            producer_output_offsets.append(offset)
+                                        offset += 1
+                                else:
+                                    producer_rank = self.producer_stage_ids[i-1]
+                                    producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
+                                    producer_DAG_node = DAG[producer_partition_name]
+                                    producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
+
+                            else:
+                                producer_rank = self.producer_stage_ids[i]
+                                producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
+                                producer_DAG_node = DAG[producer_partition_name]
+                                producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]

+                            if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
+                                producer_outputs[producer_partition_name] = [args_from_one_mod]
+                            else:
+                                producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
+
+                        cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
+
+                        def get_input_len(DAG_node_input):
+                            res = 0
+                            for offsets in DAG_node_input.values():
+                                res += len(offsets)
+                            return res
+
+                        input_len = get_input_len(cur_DAG_node_input)
+                        flatten_args = [None] * input_len
+                        for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
+                            for i, args_input_offset in enumerate(args_input_offsets):
+                                flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
+
+                        args_or_kwargs = flatten_args
+        return args_or_kwargs
+
     @abstractmethod
     def _get_work_item_key(self) -> UniqueKey:
@@ -354,6 +564,9 @@ class WorkerBase(ABC):
     def is_last_stage(self):
         return self.pp_rank == self.actual_stage_num - 1

+    def need_model_input(self):
+        return not self.is_first_stage() and self._is_input
+
     def _default_data_process_func(self, args_kwargs):
         if self.is_first_stage():
             args = args_kwargs[0]

@@ -390,11 +603,11 @@ class WorkerBase(ABC):

         # parse and integrate args and kwargs
         if is_first_stage:
-            args = get_real_args_kwargs(args)
-            kwargs = get_real_args_kwargs(kwargs)
+            args = self._get_real_args_kwargs(args)
+            kwargs = self._get_real_args_kwargs(kwargs)
             args_kwargs = (args, kwargs)
         else:
-            args_kwargs = get_real_args_kwargs(args)
+            args_kwargs = self._get_real_args_kwargs(args)

         args, kwargs = data_process_func(args_kwargs)

@@ -486,7 +699,7 @@ class WorkerBase(ABC):

         # overlap recompute and future.wait
         if not is_last_stage:
-            grad_tensors = get_real_args_kwargs(args)
+            grad_tensors = self._get_real_args_kwargs(args)
         else:
             grad_tensors = None

@@ -569,6 +782,9 @@ class WorkerBase(ABC):
         self._reset_context()

     def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
+        # After tracer fixed, remove this part.
+        if len(list(self.module_partition.parameters())) > 0:
             self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
         self.step_lock = threading.Lock()
         self.step_lock.acquire()

@@ -577,6 +793,9 @@ class WorkerBase(ABC):
             self.step_lock.acquire()

     def step(self):
+        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
+        # After tracer fixed, remove this part.
+        if len(list(self.module_partition.parameters())) > 0:
             self.optimizer.step()
             self.optimizer.zero_grad()
         self.step_lock.release()
@@ -20,6 +20,18 @@ def color_debug(text, prefix=' ', color='blue'):
     color = color.upper()
     print(getattr(Back, color), prefix, Style.RESET_ALL, text)


+class MLP(nn.Module):
+
+    def __init__(self, dim: int, layers: int):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+        for _ in range(layers):
+            self.layers.append(nn.Linear(dim, dim, bias=False))
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
 class RpcTestModel(nn.Module):

tests/test_pipeline/test_middleware_1f1b.py  (new file, 60 lines)
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
+from colossalai.fx import ColoTracer
+from rpc_test_utils import rpc_run, parse_args, MLP
+from functools import partial
+
+# global variable for model created
+batch_size = 16
+dim = 10
+
+def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+    model.eval()
+    tracer = ColoTracer()
+    meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
+    annotated_model = balanced_split_pass(gm, stage_num)
+    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
+    return list(split_model.children())[pp_rank]
+
+def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
+    torch.manual_seed(1024)
+    model = MLP(dim, stage_num * 3)
+    partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
+    return partition
+
+def run_master(args):
+    torch.manual_seed(100)
+
+    epoch = args.epoch
+    device = args.device
+    stage_num = args.world_size
+    chunk = args.chunk
+    num_microbatches = args.num_microbatches
+    use_checkpoint = args.use_checkpoint
+
+    input_sample = torch.randn((batch_size, dim), device=device)
+
+    def data_gen():
+        x = torch.zeros((batch_size, dim))
+        kwargs = dict(x=x)
+        return kwargs
+
+    data_kwargs = data_gen()
+    engine = OneFOneBPipelineEngine(partition_fn=partial(partition, data_kwargs),
+                                    stage_num=stage_num,
+                                    num_microbatches=num_microbatches,
+                                    device=device,
+                                    chunk=chunk,
+                                    checkpoint=use_checkpoint)
+
+    for _ in range(epoch):
+        logits = engine.forward_backward({'x': input_sample}, forward_only=True)
+
+if __name__ == "__main__":
+    args = parse_args()
+    rpc_run(args, run_master)