Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-20 20:54:55 +00:00
[rpc] split with dag (#2028)
* add DAG to split_module
* add comment
* add test case for DAG
* remove print
* add DAG middleware in scheduler
* add test case for scheduler
* remove break
* recover old lifecycle

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in: parent 96134e7be3, commit b0936e4a44
@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
return gm

def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
# TODO(lyl): use partition IR to assign partition ID to each node.
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
part_idx += 1
return part_idx

split_mod = split_module(annotated_gm, None, split_callback)
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
split_submodules = []
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
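For orientation, here is a minimal sketch of how the new merge_output flag is exercised; it mirrors the test added later in this commit, and the toy model, shapes and stage count are placeholder assumptions rather than part of the change itself.

import torch
from torch import nn
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass

class TinyMLP(nn.Module):    # placeholder model, not part of the commit
    def __init__(self, dim=10, layers=4):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(layers)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = TinyMLP()
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'x': torch.zeros(16, 10).to('meta')})
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_gm = balanced_split_pass(gm, 4)                                    # insert split nodes for 4 stages
split_mod, split_submodules = split_with_split_nodes_pass(annotated_gm, merge_output=True)
stage_module = list(split_mod.children())[0]                                 # submodule owned by pipeline rank 0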
@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
for partition in partitions:
if node == partition:
user_partition_names.append(partition.name)

# find user with getitem call
else:
for partition in partitions:
if node in partition.args:
user_partition_names.append(partition.name)

is_output = False
def find_output(def_node, output_node):
nonlocal is_output
if def_node == output_node:
is_output = True

if output_partitions is not None:
output_node = output_partitions[0]
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))

if is_output:
user_partition_names.append('MODEL_OUTPUT')
if node.op == output_node.op:
user_partition_names.append('MODEL_OUTPUT')

if len(user_partition_names) > 0:
return user_partition_names
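The 'MODEL_OUTPUT' sentinel above feeds the DAG annotation that the scheduler consumes further down. The exact schema is not spelled out in this diff; the sketch below is an inference from how later code indexes it ('input', 'output', 'output_len', the 'input_partition' entry, and the 'MODEL_INPUT'/'MODEL_OUTPUT' sentinels), shown for a hypothetical two-stage pipeline.

# hypothetical layout of the module's _DAG for submod_0 -> submod_1, one model input, one model output
dag = {
    'input_partition': {
        0: {'output': {'submod_0': [0]}},      # model input #0 is consumed by submod_0 at arg offset 0
    },
    'submod_0': {
        'input': {'MODEL_INPUT': [0]},         # producer name -> offsets in this stage's flattened args
        'output': {'submod_1': [0]},           # consumer name -> offsets of this stage's outputs it uses
        'output_len': 1,
    },
    'submod_1': {
        'input': {'submod_0': [0]},
        'output': {'MODEL_OUTPUT': [0]},
        'output_len': 1,
    },
}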
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
import torch.distributed.rpc as rpc
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
@@ -20,7 +20,7 @@ class Phase(Enum):
FORWARD = 0
BACKWARD = 1
UPDATE = 2

INPUT = 3

class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
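The new Phase.INPUT value tags the model-input slice that the first stage stashes for stages whose partitions consume raw model inputs; together with the microbatch id it forms the UniqueKey used throughout the worker. A small illustration:

# every work/output item is addressed by a (microbatch_id, phase) pair
forward_key = UniqueKey(0, Phase.FORWARD)   # forward work of microbatch 0
input_key = UniqueKey(0, Phase.INPUT)       # model-input slice of microbatch 0 published by stage 0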
@@ -128,6 +128,7 @@ class WorkerBase(ABC):
# topology info
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
self.input_consumer_stage_ids: List[int] = None

# module partitions
self.partition_fn = partition_fn
@@ -135,6 +136,11 @@ class WorkerBase(ABC):
self.criterion = criterion
self.metric = metric

# middleware info
self._is_input = False
self._is_output = False
self._producer_consumer_initialized = False

# context to maintain loop
self._initialize_context_container()
@@ -164,6 +170,7 @@ class WorkerBase(ABC):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.producer_consumer_init_lock = threading.Condition(threading.Lock())

def _initialize_partition(self):
partition_fn = self.partition_fn
@@ -182,7 +189,7 @@ class WorkerBase(ABC):
# construction of partition is executed after the registion of pp_rank_to_worker_rref
self._initialize_partition()

def get_output_by_key(self, key: UniqueKey) -> Any:
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
@@ -191,8 +198,9 @@ class WorkerBase(ABC):
if isinstance(output, Future):
output = output.wait()

output_work_item.refcount += 1
# output_work_item.refcount += 1

# TODO(jiangziyue) redesign lifecycle management for DAG scheduler
# all consumers have been satisfied, the work_item can be released
with self.output_list_condition_lock:
if output_work_item.refcount >= len(self.consumer_stage_ids):
@@ -215,8 +223,10 @@ class WorkerBase(ABC):
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()

def _make_args_kwargs(self, microbatch):
def _make_args_kwargs(self, microbatch, merge=False):
if isinstance(microbatch, dict):
if merge:
return list(microbatch.values()), {}
return [], microbatch
elif isinstance(microbatch, torch.Tensor):
return [microbatch], {}
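A quick sketch of what the new merge flag changes for a dict microbatch; worker stands for any initialized WorkerBase subclass (an assumption for illustration), and the values are flattened into positional args instead of being forwarded as kwargs.

microbatch = {'x': torch.randn(4, 10), 'mask': torch.ones(4, 10)}
args, kwargs = worker._make_args_kwargs(microbatch)              # -> [], {'x': ..., 'mask': ...}
args, kwargs = worker._make_args_kwargs(microbatch, merge=True)  # -> [x_tensor, mask_tensor], {}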
@@ -228,24 +238,70 @@ class WorkerBase(ABC):
kwargs.update(arg)
else:
args.append(arg)
if merge:
arg_lst = args
for arg in kwargs.values():
arg_lst.append(arg)
return arg_lst, {}
return args, kwargs
else:
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")

# just for first pp_rank
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
# TODO(jiangziyue) Define a Class for DAG.
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()

if not self.use_middleware():
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)

# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
else:
# make args and kwargs
arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)

work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# first stage assign correct input into other stages
DAG = self.get_DAG()
DAG_node = DAG['input_partition']
self_input_offsets = []
recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
# notify rank which should receive extra input
offset = 0
for details in DAG_node.values():
for partition_name in details['output'].keys():
recv_rank = self.partition_name_to_pp_rank(partition_name)
if recv_rank == self.pp_rank:
self_input_offsets.append(offset)
elif recv_rank not in self.input_consumer_stage_ids:
self.input_consumer_stage_ids.append(recv_rank)
offset += 1

# set input for self rank
self_arg_lst = []
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])

work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()

# put input tensor which other nodes need into output_list
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)

with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
self.output_list_condition_lock.notify_all()

# just for last pp_rank
def set_labels(self, microbatch_id: int, microlabels: Any):
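To make the middleware branch of set_input easier to follow, here is a toy, self-contained version of its routing loop, using the hypothetical DAG layout sketched earlier; names and the offset convention are assumptions, and it only illustrates how offsets into the merged arg_lst are assigned to stages.

def route_model_inputs(input_partition: dict, arg_lst: list, self_partition: str):
    """Toy sketch: return the args kept by this (first) stage and the partitions
    that must fetch their share of the model input via Phase.INPUT."""
    self_offsets, remote_partitions = [], set()
    for offset, details in enumerate(input_partition.values()):
        for partition_name in details['output'].keys():
            if partition_name == self_partition:
                self_offsets.append(offset)
            else:
                remote_partitions.add(partition_name)
    return [arg_lst[off] for off in self_offsets], remote_partitions

# with the toy DAG above: everything goes to submod_0, nothing needs to be published remotely
self_args, remote = route_model_inputs({0: {'output': {'submod_0': [0]}}}, ['x_tensor'], 'submod_0')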
@@ -268,33 +324,68 @@ class WorkerBase(ABC):
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()

# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
"""
You should call this function asynchronously
"""
assert self.producer_stage_ids is not None
producer_num = len(self.producer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"

stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
if not self.use_middleware():
producer_num = len(self.producer_stage_ids)
subscribe_forward_futures: List[Future] = [None] * producer_num
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
else:
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
if key in self.work_list:
return

for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
producer_stage_ids = []
with self.producer_consumer_init_lock:
self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
producer_stage_ids = self.producer_stage_ids
producer_num = len(producer_stage_ids)

# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_num += 1 # extra one(the last one) for input_tensor
subscribe_forward_futures: List[Future] = [None] * producer_num

# TODO(jiangziyue) get single value instead of the whole output
if self.need_model_input():
producer_stage_id = 0
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

for i in range(0, producer_num-1):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

else:
for i in range(producer_num):
producer_stage_id = producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
#producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)

work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
microbatch_id, None, self.num_microbatches, forward_only)

# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
if key not in self.work_list:
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()

def subscribe_consumer(self, microbatch_id: int):
"""
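One convention worth calling out in the middleware path of subscribe_producer: when a non-first stage also needs raw model input, slot 0 of subscribe_forward_futures holds the Phase.INPUT fetch from stage 0 and the regular producers shift to slots i+1; _get_real_args_kwargs later relies on that ordering (index 0 meaning the input partition). The sketch below is illustrative only, and stage0_rref, rrefs, mb_id and self_rank are placeholder names.

# illustrative layout of subscribe_forward_futures when need_model_input() is True
futures = [None] * (len(producer_stage_ids) + 1)
futures[0] = stage0_rref.rpc_async().get_output_by_key(UniqueKey(mb_id, Phase.INPUT), self_rank)
for i, producer in enumerate(producer_stage_ids):
    futures[i + 1] = rrefs[producer].rpc_async().get_output_by_key(UniqueKey(mb_id, Phase.FORWARD), self_rank)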
@@ -334,13 +425,132 @@ class WorkerBase(ABC):
self.producer_stage_ids = []
self.consumer_stage_ids = []

# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
if not self.use_middleware():
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
else:
self.input_consumer_stage_ids = []
DAG = self.get_DAG()
DAG_node_name = self.pp_rank_to_partition_name(rank)
DAG_node = DAG[DAG_node_name]
for partition_name in DAG_node['input'].keys():
if partition_name == 'MODEL_INPUT':
self._is_input = True
else:
prev_rank = self.partition_name_to_pp_rank(partition_name)
self.producer_stage_ids.append(prev_rank)

for partition_name in DAG_node['output'].keys():
if partition_name == 'MODEL_OUTPUT':
self._is_output = True
else:
next_rank = self.partition_name_to_pp_rank(partition_name)
self.consumer_stage_ids.append(next_rank)

# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
with self.producer_consumer_init_lock:
self._producer_consumer_initialized = True
self.producer_consumer_init_lock.notify_all()

# TODO(jiangziyue) Define a Class for DAG.
def pp_rank_to_partition_name(self, pp_rank: int):
prefix = 'submod_'
partition_name = prefix + str(pp_rank)
return partition_name

# TODO(jiangziyue) Define a Class for DAG.
def partition_name_to_pp_rank(self, partition_name: str) -> int:
prefix = 'submod_'
pp_rank = int(partition_name.split(prefix)[-1])
return pp_rank

def get_DAG(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_DAG'):
return self.module_partition._DAG
else:
return None

def use_middleware(self):
DAG = self.get_DAG()
return DAG is not None

# TODO(jiangziyue) get single value instead of the whole output
def _get_real_args_kwargs(self, args_or_kwargs):
if not self.use_middleware():
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
else:
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
# TODO get by offset
else:
DAG = self.get_DAG()
producer_outputs = {}
cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
#cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
for i, args_from_one_mod in enumerate(args_or_kwargs):
producer_output_offsets = []
if self.need_model_input():
if i == 0:
producer_DAG_node = DAG['input_partition']
producer_partition_name = 'MODEL_INPUT'
offset = 0
for arg_info in producer_DAG_node.values():
if cur_DAG_node_name in arg_info['output']:
producer_output_offsets.append(offset)
offset += 1
else:
producer_rank = self.producer_stage_ids[i-1]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]

else:
producer_rank = self.producer_stage_ids[i]
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
producer_DAG_node = DAG[producer_partition_name]
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]

if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
producer_outputs[producer_partition_name] = [args_from_one_mod]
else:
producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]

cur_DAG_node_input = DAG[cur_DAG_node_name]['input']

def get_input_len(DAG_node_input):
res = 0
for offsets in DAG_node_input.values():
res += len(offsets)
return res

input_len = get_input_len(cur_DAG_node_input)
flatten_args = [None] * input_len
for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
for i, args_input_offset in enumerate(args_input_offsets):
flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]

args_or_kwargs = flatten_args
return args_or_kwargs

@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
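The two naming helpers above are inverses built on the torch.fx split_module convention that pipeline rank i owns submodule 'submod_i'; spelled out on plain strings:

prefix = 'submod_'
partition_name = prefix + str(2)                  # 'submod_2'  (pp_rank_to_partition_name)
pp_rank = int(partition_name.split(prefix)[-1])   # 2           (partition_name_to_pp_rank)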
@@ -353,6 +563,9 @@ class WorkerBase(ABC):

def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1

def need_model_input(self):
return not self.is_first_stage() and self._is_input

def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
@@ -390,11 +603,11 @@ class WorkerBase(ABC):

# parse and integrate args and kwargs
if is_first_stage:
args = get_real_args_kwargs(args)
kwargs = get_real_args_kwargs(kwargs)
args = self._get_real_args_kwargs(args)
kwargs = self._get_real_args_kwargs(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = get_real_args_kwargs(args)
args_kwargs = self._get_real_args_kwargs(args)

args, kwargs = data_process_func(args_kwargs)
@@ -486,7 +699,7 @@ class WorkerBase(ABC):

# overlap recompute and future.wait
if not is_last_stage:
grad_tensors = get_real_args_kwargs(args)
grad_tensors = self._get_real_args_kwargs(args)
else:
grad_tensors = None
@@ -569,7 +782,10 @@ class WorkerBase(ABC):
self._reset_context()

def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
@@ -577,8 +793,11 @@ class WorkerBase(ABC):
self.step_lock.acquire()

def step(self):
self.optimizer.step()
self.optimizer.zero_grad()
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
# After tracer fixed, remove this part.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.step_lock.release()
@@ -20,6 +20,18 @@ def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)

class MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()

for _ in range(layers):
self.layers.append(nn.Linear(dim, dim, bias=False))

def forward(self, x):
for layer in self.layers:
x = layer(x)
return x

class RpcTestModel(nn.Module):
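A minimal usage sketch of the MLP helper added above (the shapes match the new middleware test):

model = MLP(dim=10, layers=6)     # six bias-free Linear(10, 10) layers applied in sequence
x = torch.randn(16, 10)
y = model(x)                      # y.shape == torch.Size([16, 10])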
tests/test_pipeline/test_middleware_1f1b.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import torch
from torch import nn

from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from rpc_test_utils import rpc_run, parse_args, MLP
from functools import partial

# global variable for model created
batch_size = 16
dim = 10

def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
    model.eval()
    tracer = ColoTracer()
    meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
    annotated_model = balanced_split_pass(gm, stage_num)
    split_model, _ = split_with_split_nodes_pass(annotated_model, merge_output=True)
    return list(split_model.children())[pp_rank]

def partition(data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
    torch.manual_seed(1024)
    model = MLP(dim, stage_num * 3)
    partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
    return partition

def run_master(args):
    torch.manual_seed(100)

    epoch = args.epoch
    device = args.device
    stage_num = args.world_size
    chunk = args.chunk
    num_microbatches = args.num_microbatches
    use_checkpoint = args.use_checkpoint

    input_sample = torch.randn((batch_size, dim), device=device)

    def data_gen():
        x = torch.zeros((batch_size, dim))
        kwargs = dict(x=x)
        return kwargs

    data_kwargs = data_gen()
    engine = OneFOneBPipelineEngine(partition_fn=partial(partition, data_kwargs),
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    checkpoint=use_checkpoint)

    for _ in range(epoch):
        logits = engine.forward_backward({'x': input_sample}, forward_only=True)

if __name__ == "__main__":
    args = parse_args()
    rpc_run(args, run_master)
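A note on the partition_fn contract the test relies on: partial(partition, data_kwargs) binds the traced example inputs, so the engine is presumably left to call the function per stage with the remaining arguments; this call pattern is an assumption inferred from the signature and is not shown in this diff.

from functools import partial

partition_fn = partial(partition, data_kwargs)
# what a worker would effectively invoke for its own stage (hypothetical call site):
stage_module = partition_fn(pp_rank, chunk, stage_num)   # == partition(data_kwargs, pp_rank, chunk, stage_num)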