[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
author Hongxin Liu
date 2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,4 +1,4 @@
from .layer_spec import LayerSpec
from .pipelinable import PipelinableContext, PipelinableModel
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
__all__ = ["PipelinableModel", "PipelinableContext", "LayerSpec"]

View File

@@ -4,9 +4,7 @@ from colossalai.utils.model.utils import call_to_str
class LayerSpec:
"""
"""
""" """
def __init__(self, typename, *module_args, **module_kwargs):
self.typename = typename
@@ -16,7 +14,7 @@ class LayerSpec:
self._param_count = 0
if not issubclass(typename, torch.nn.Module):
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
raise RuntimeError("LayerSpec only supports torch.nn.Module types.")
def __repr__(self):
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

View File

@@ -1,3 +1,3 @@
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
__all__ = ["Topo", "Partition", "PartitionOutputVal", "PartitionInputVal"]

View File

@@ -1,3 +1,3 @@
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']
__all__ = ["get_fx_topology"]

View File

@@ -10,7 +10,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
elif is_output:
partition_id = 1
else:
prefix = 'submod_'
prefix = "submod_"
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
@@ -27,10 +27,10 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
direct_def = not node.name.startswith("getitem")
# search in input
if direct_def and input_partitions is not None:
partition_id = partition_name_to_id('', is_input=True)
partition_id = partition_name_to_id("", is_input=True)
for i, input_node in enumerate(input_partitions):
if input_node == node:
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
@@ -57,7 +57,7 @@ def find_input_in_partition(node, partitions, input_partitions=None):
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
direct_use = not user.name.startswith('getitem')
direct_use = not user.name.startswith("getitem")
# user is mid partition
for partition in partitions:
# direct call
@@ -82,7 +82,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
output_node = output_partitions[0]
if user.op == output_node.op:
output_keys = {}
partition_id = partition_name_to_id('', is_output=True)
partition_id = partition_name_to_id("", is_output=True)
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
for i, arg in enumerate(output_keys):
if arg == node:
@@ -99,11 +99,11 @@ def get_topology(gm: GraphModule):
partitions = []
output_partitions = []
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
input_partitions.append(node)
elif node.name.startswith('submod_'):
elif node.name.startswith("submod_"):
partitions.append(node)
elif node.op == 'output':
elif node.op == "output":
output_partitions.append(node)
else:
continue
@@ -127,7 +127,7 @@ def get_topology(gm: GraphModule):
# set output for submodule
direct_use = True
for user in partition.users:
if user.name.startswith('getitem'):
if user.name.startswith("getitem"):
direct_use = False
break
if direct_use:
@@ -146,7 +146,8 @@ def get_topology(gm: GraphModule):
topo_output_partition = Partition()
torch.fx.graph.map_arg(
partition.args[0],
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)),
)
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)

View File

@@ -10,7 +10,7 @@ class ValPosition:
offset: int
def __str__(self) -> str:
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
res = f"[partition_id:{self.partition_id},offset:{self.offset}]"
return res
def __repr__(self) -> str:
@@ -18,7 +18,6 @@ class ValPosition:
class PartitionInputVal(object):
def __init__(self, partition_id, offset) -> None:
# every input from which partition_id and which offset
val_pos = ValPosition(partition_id, offset)
@@ -28,8 +27,8 @@ class PartitionInputVal(object):
return self._from_partition_and_offset
def __str__(self) -> str:
res = ''
res += f'<-({self._from_partition_and_offset})'
res = ""
res += f"<-({self._from_partition_and_offset})"
return res
def __repr__(self) -> str:
@@ -37,7 +36,6 @@ class PartitionInputVal(object):
class PartitionOutputVal(object):
def __init__(self) -> None:
# every output to which partition_id and which offset
self._to_partition_and_offset: List[ValPosition] = []
@@ -50,11 +48,11 @@ class PartitionOutputVal(object):
return self._to_partition_and_offset
def __str__(self) -> str:
res = ''
res += '->('
res = ""
res += "->("
for val_pos in self._to_partition_and_offset:
res += f'{val_pos},'
res += ')'
res += f"{val_pos},"
res += ")"
return res
def __repr__(self) -> str:
@@ -62,7 +60,6 @@ class PartitionOutputVal(object):
class Partition(object):
def __init__(self) -> None:
self._input_vals: List[PartitionInputVal] = []
self._output_vals: List[PartitionOutputVal] = []
@@ -110,16 +107,16 @@ class Partition(object):
return res
def __str__(self) -> str:
res = ''
res += f' input:\n'
res += f' length:{len(self._input_vals)}\n'
res = ""
res += f" input:\n"
res += f" length:{len(self._input_vals)}\n"
for i, input_val in enumerate(self._input_vals):
res += f' offset={i}:{input_val}\n'
res += f" offset={i}:{input_val}\n"
res += f' output:\n'
res += f' length:{len(self._output_vals)}\n'
res += f" output:\n"
res += f" length:{len(self._output_vals)}\n"
for i, output_val in enumerate(self._output_vals):
res += f' offset={i}:{output_val}\n'
res += f" offset={i}:{output_val}\n"
return res
@@ -140,7 +137,6 @@ class Partition(object):
# _input_partition_id: the key represents input_partition
# _output_partition_id: the key represents output_partition
class Topo(object):
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
self._partitions: Dict[int, Partition] = {}
self._input_partition_id = input_partition_id
@@ -162,7 +158,7 @@ class Topo(object):
self._partitions[partition_id] = partition
def get_mid_partitions(self):
res = {} #{partition_id: Partition}
res = {} # {partition_id: Partition}
for partition_id, partition in self._partitions.items():
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
continue
@@ -186,27 +182,27 @@ class Topo(object):
return self._partitions[partition_id]
def __str__(self) -> str:
res = ''
res = ""
if len(self._partitions) == 0:
return 'Empty Topo Graph.'
return "Empty Topo Graph."
input_part = self.get_input_partition()
if input_part is not None:
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'
res += "{\n"
res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}"
res += "}\n"
mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'
res += "{\n"
res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}"
res += "}\n"
output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'
res += "{\n"
res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}"
res += "}\n"
return res

View File

@@ -132,8 +132,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
for child in self._root_children:
layer_spec = self._layer_spec_dict[id(child)]
if layer_spec.typename in (
torch.nn.modules.container.ModuleList,
torch.nn.modules.container.Sequential,
):
for child_in_container in layer_spec.children:
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
@@ -198,8 +198,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
param_counts.append(layer_spec.count_params())
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
elif self._policy == "customized":
assert (self._exec_seq
is not None), f"An explicit exec_seq must be defined by user in customized policy mode."
assert (
self._exec_seq is not None
), f"An explicit exec_seq must be defined by user in customized policy mode."
self.customized_parts = customized_partition(self._exec_seq)
assert len(self.customized_parts) == gpc.get_world_size(
ParallelMode.PIPELINE
@@ -226,14 +227,14 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
elif (layer, "behind") in self._func_dict:
behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
behind_func_dict_in_partition)
pipeline_model = PipelinableModel(
module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition
)
return pipeline_model
class PipelinableModel(torch.nn.Module):
def __init__(self, module_list, front_func_dict, behind_func_dict):
super().__init__()
self._module_list = module_list

View File

@@ -1,6 +1,5 @@
import os
import threading
from typing import Dict, List, Tuple
from typing import List
import torch.distributed as dist
from torch.distributed import rpc
@@ -14,14 +13,15 @@ class PipelineProcessGroup:
def __init__(self) -> None:
self.is_initialize = False
def set_global_info(self,
rank: int,
world_size: int,
dp_degree: int = 1,
tp_degree: int = 1,
num_worker_threads: int = 1,
device: str = "cuda") -> None:
def set_global_info(
self,
rank: int,
world_size: int,
dp_degree: int = 1,
tp_degree: int = 1,
num_worker_threads: int = 1,
device: str = "cuda",
) -> None:
device_mesh_size = dp_degree * tp_degree
assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
self._num_worker_threads = num_worker_threads
@@ -60,8 +60,8 @@ class PipelineProcessGroup:
device = self.device
world_size = self.get_world_size()
rank = self.get_global_rank()
backend = 'nccl' if device == 'cuda' else 'gloo'
dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')
backend = "nccl" if device == "cuda" else "gloo"
dist.init_process_group(backend, world_size=world_size, rank=rank, group_name="main_group")
def _initialize_pp_process_group(self) -> None:
rank = self.get_global_rank()
@@ -71,9 +71,9 @@ class PipelineProcessGroup:
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)
for pp_rank in self._pp_ranks:
options.set_device_map(f'work{pp_rank}', {rank: pp_rank})
options.set_device_map(f"work{pp_rank}", {rank: pp_rank})
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
rpc.init_rpc(name=f"work{rank}", rank=rank, world_size=world_size, rpc_backend_options=options)
def _initialize_tp_dp_process_group(self) -> None:
rank = self.get_global_rank()
@@ -147,10 +147,10 @@ class PipelineProcessGroup:
def get_chimera_all_reduce_group(self, pp_rank: int):
with self.chimera_lock:
if not hasattr(self, 'chimera_groups'):
if not hasattr(self, "chimera_groups"):
world_size = self.get_world_size()
stage_num = self.get_stage_num()
assert world_size % 2 == 0, 'world_size must be even in chimera!'
assert world_size % 2 == 0, "world_size must be even in chimera!"
self.chimera_groups = {}
for rank in range(world_size // 2):
pair = [rank, world_size - 1 - rank]

View File

@@ -1,4 +1,4 @@
from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
from .utils import pytree_map
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
__all__ = ["FillDrainPipelineEngine", "OneFOneBPipelineEngine", "ChimeraPipelineEngine", "pytree_map"]

View File

@@ -12,17 +12,9 @@ from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.legacy.pipeline.middleware import Partition, Topo
from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.legacy.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
pytree_map,
split_batch,
tensor_shape_list,
type_detail,
)
from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch
class Phase(Enum):
@@ -33,7 +25,7 @@ class Phase(Enum):
class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
__slots__ = ("microbatch_id", "phase")
microbatch_id: int
phase: Phase
@@ -48,12 +40,22 @@ class UniqueKey:
return tuple.__hash__((self.microbatch_id, self.phase))
def __repr__(self) -> str:
return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})'
return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})"
class WorkItem:
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
'num_microbatches', 'forward_only')
__slots__ = (
"stage_id",
"phase",
"args",
"kwargs",
"output",
"refcount",
"microbatch_id",
"batch_id",
"num_microbatches",
"forward_only",
)
stage_id: int
phase: Phase
@@ -66,50 +68,45 @@ class WorkItem:
num_microbatches: int
forward_only: bool
def __init__(self,
stage_id,
phase,
args,
kwargs,
output,
microbatch_id,
batch_id,
num_microbatches,
forward_only,
refcount=0) -> None:
def __init__(
self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0
) -> None:
for attr_name in self.__slots__:
setattr(self, attr_name, locals()[attr_name])
class BackwardCache:
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
__slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs")
checkpoint: bool
stage_input_args: Tuple[Any]
stage_input_kwargs: Dict[Any, Any]
stage_outputs: Tuple[Any]
def __init__(self,
stage_input_args: Tuple[Any],
stage_input_kwargs: Dict[Any, Any] = None,
stage_outputs: Tuple[Any] = None,
checkpoint: bool = False) -> None:
def __init__(
self,
stage_input_args: Tuple[Any],
stage_input_kwargs: Dict[Any, Any] = None,
stage_outputs: Tuple[Any] = None,
checkpoint: bool = False,
) -> None:
for arg_name in self.__slots__:
setattr(self, arg_name, locals()[arg_name])
class WorkerBase(ABC):
def __init__(self,
partition_fn: Callable,
partition_args: tuple,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
partition_fn: Callable,
partition_args: tuple,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
super().__init__()
self.pp_rank = pp_rank
@@ -150,11 +147,11 @@ class WorkerBase(ABC):
self._initialize_context_container()
# main loop
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True)
self.main_loop_thread.start()
def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device])
def _initialize_outstanding_range(self):
outstanding_range = None
@@ -199,12 +196,13 @@ class WorkerBase(ABC):
# lifecycle management for DAG scheduler
if output_work_item.phase == Phase.FORWARD:
lifecycle = len(self.get_consumer_stage_ids())
if self.is_model_output(): # an extra reference for scheduler collecting results
if self.is_model_output(): # an extra reference for scheduler collecting results
lifecycle += 1
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self.is_model_input() and self._is_last_step(
output_work_item): # an extra reference for ensure_backward
output_work_item
): # an extra reference for ensure_backward
lifecycle += 1
else:
lifecycle = 0
@@ -234,9 +232,9 @@ class WorkerBase(ABC):
# offset supports get partial output to reduce comm costs.
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
output = self._get_output_all(key, ref_use, rank)
if offsets is None: # get all for non iterable output
if offsets is None: # get all for non iterable output
return output
else: # get part for iterable output
else: # get part for iterable output
output = [output[i] for i in offsets]
return output
@@ -252,12 +250,12 @@ class WorkerBase(ABC):
def get_partition(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
return self.module_partition
def get_partition_state_dict(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch, merge=False):
@@ -293,8 +291,17 @@ class WorkerBase(ABC):
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
work_item = WorkItem(
self.pp_rank,
Phase.FORWARD,
args,
kwargs,
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@@ -314,15 +321,25 @@ class WorkerBase(ABC):
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
work_item = WorkItem(
self.pp_rank,
Phase.FORWARD,
self_arg_lst,
{},
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list as Phase.INPUT
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
work_item_remote = WorkItem(
self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only
)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
@@ -343,8 +360,17 @@ class WorkerBase(ABC):
output = self._get_future_by_device()
grad_wrt_loss = None
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False)
work_item = WorkItem(
self.pp_rank,
Phase.BACKWARD,
grad_wrt_loss,
{},
output,
microbatch_id,
None,
self.num_microbatches,
False,
)
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@@ -367,7 +393,7 @@ class WorkerBase(ABC):
producer_stage_ids = self.get_producer_stage_ids()
producer_num = len(producer_stage_ids)
if self.need_model_input():
producer_num += 1 # for input partition
producer_num += 1 # for input partition
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
@@ -376,9 +402,9 @@ class WorkerBase(ABC):
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
offsets = self._get_input_offsets_by_index(target_index=0)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
rank=self.pp_rank,
offsets=offsets)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets
)
for i in range(0, producer_num - 1):
producer_stage_id = producer_stage_ids[i]
@@ -386,11 +412,12 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i + 1
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
producer_output_key, rank=self.pp_rank, offsets=offsets
)
else:
for i in range(producer_num):
@@ -399,14 +426,24 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
producer_output_key, rank=self.pp_rank, offsets=offsets
)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
work_item_from_producer = WorkItem(
stage_id,
Phase.FORWARD,
subscribe_forward_futures,
{},
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
return work_item_from_producer
@@ -441,15 +478,25 @@ class WorkerBase(ABC):
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
target_index = i
offsets = self._get_output_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_backward_futures[target_index] = []
else:
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
consumer_output_key, rank=self.pp_rank, offsets=offsets)
consumer_output_key, rank=self.pp_rank, offsets=offsets
)
# flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
work_item_from_consumer = WorkItem(
stage_id,
Phase.BACKWARD,
subscribe_backward_futures,
{},
output,
microbatch_id,
None,
self.num_microbatches,
False,
)
return work_item_from_consumer
@@ -524,8 +571,8 @@ class WorkerBase(ABC):
def get_topo(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_topo'):
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
if hasattr(self.module_partition, "_topo"):
return self.module_partition._topo
else:
return None
@@ -564,12 +611,12 @@ class WorkerBase(ABC):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if target_index == src_index:
if output_len == 1:
res = None # offset = None to get all outputs
res = None # offset = None to get all outputs
return res
else:
res.append(src_offset)
@@ -584,7 +631,6 @@ class WorkerBase(ABC):
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to many down stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
@@ -597,7 +643,7 @@ class WorkerBase(ABC):
break
if target_index == dst_index:
if input_len == 1:
res = None # offset = None to get all outputs
res = None # offset = None to get all outputs
return res
else:
res.append(dst_offset)
@@ -623,7 +669,7 @@ class WorkerBase(ABC):
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
else: # get by offset
else: # get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -652,7 +698,7 @@ class WorkerBase(ABC):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if output_len == 1:
@@ -679,7 +725,7 @@ class WorkerBase(ABC):
else:
for i, arg in enumerate(args_or_kwargs):
args_or_kwargs[i] = arg.wait()
if args_or_kwargs is not None: # get by offset
if args_or_kwargs is not None: # get by offset
flatten_args = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
@@ -719,7 +765,7 @@ class WorkerBase(ABC):
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
"""
this method control the order of the microbatch to consume
"""
def is_first_stage(self):
@@ -761,7 +807,7 @@ class WorkerBase(ABC):
kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
data_process_func = getattr(self, "data_process_func", self._default_data_process_func)
consume_result = None
is_first_stage = self.is_first_stage()
@@ -787,10 +833,12 @@ class WorkerBase(ABC):
else:
args_kwargs = self._get_real_args_kwargs_fwd(args)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args_kwargs = pyobj_map(
args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(
args_kwargs, fn=lambda x: self.device, process_types=torch.device
) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs)
@@ -851,16 +899,16 @@ class WorkerBase(ABC):
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(
stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint
)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
if is_last_stage: # if it is the last stage, trigger backward automatic
self._begin_backward(microbatch_id)
elif phase == Phase.BACKWARD:
@@ -872,7 +920,9 @@ class WorkerBase(ABC):
self.backward_times += 1
self.outstanding -= 1
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
assert (
microbatch_id in self.microbatch_id_to_backward_cache
), f"microbatch_id {microbatch_id} not in backward cache"
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
stage_outputs = backward_cache.stage_outputs
@@ -906,8 +956,9 @@ class WorkerBase(ABC):
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
grad_tensors = pyobj_map(
filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
@@ -920,8 +971,8 @@ class WorkerBase(ABC):
else:
consume_result.append(None)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -929,7 +980,7 @@ class WorkerBase(ABC):
return consume_result
def _get_store_len(self):
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}"
def _get_parameter_grad_sum(self):
grad_sum = 0
@@ -1014,19 +1065,20 @@ class WorkerBase(ABC):
class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
worker_type,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
worker_type,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
super().__init__()
self.worker_type = worker_type
self.partition_fn: Callable = partition_fn
@@ -1056,12 +1108,12 @@ class PipelineEngineBase(ABC, nn.Module):
data_process_func = self.data_process_func
if data_process_func is not None:
assert callable(data_process_func), "data_process_func must be a function"
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
assert "<locals>" not in data_process_func.__repr__(), "data_process_func must be a global function"
assert "<lambda>" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
sig = inspect.signature(data_process_func)
assert len(
sig.parameters
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
assert (
len(sig.parameters) == 2
), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
@@ -1104,19 +1156,33 @@ class PipelineEngineBase(ABC, nn.Module):
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
partition_args = (partition_id, chunk, actual_stage_num)
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
worker_type,
args=(partition_fn, partition_args, pp_rank,
actual_stage_num, num_microbatches, device,
criterion, metric, checkpoint, data_process_func))
if device[:4] == "cuda":
device = f"cuda:{rpc_worker_id}"
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(
rpc_worker_id,
worker_type,
args=(
partition_fn,
partition_args,
pp_rank,
actual_stage_num,
num_microbatches,
device,
criterion,
metric,
checkpoint,
data_process_func,
),
)
# let each worker know global worker rref (include itself)
sync_futs = []
for pp_rank in self.pp_rank_to_worker_rref:
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
self.pp_rank_to_worker_rref)
fut = (
self.pp_rank_to_worker_rref[pp_rank]
.rpc_async(timeout=0)
.sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
)
sync_futs.append(fut)
for fut in sync_futs:
@@ -1157,8 +1223,9 @@ class PipelineEngineBase(ABC, nn.Module):
def get_output_pp_ranks(self) -> List[int]:
return [self._get_actual_stage_num() - 1]
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
def _consume_constraint(
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
):
actual_stage_num = self._get_actual_stage_num()
use_1F1B = self.use_1F1B
if microbatch_id >= actual_stage_num:
@@ -1206,7 +1273,8 @@ class PipelineEngineBase(ABC, nn.Module):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
fut = worker_rref.rpc_async().get_output_by_key(
key, offsets=[]) # only ensure the res exists, no need for real data.
key, offsets=[]
) # only ensure the res exists, no need for real data.
backward_result.append(fut)
for fut in backward_result:
@@ -1244,11 +1312,14 @@ class PipelineEngineBase(ABC, nn.Module):
if labels is not None and not forward_only:
assert hasattr(
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
self, "optimizer_class"
), "call `initialize_optimizer` to initialize optimizer before forward_backward"
num_microbatches = self.num_microbatches
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
assert (
batch_length >= num_microbatches
), "num_microbatches is greater than the size of a batch, which is illegal"
microbatch_size = math.ceil(batch_length / num_microbatches)
device = self.device
@@ -1285,10 +1356,10 @@ class PipelineEngineBase(ABC, nn.Module):
# collect forward result
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and hasattr(self, 'optimizer_class'):
if not forward_only and hasattr(self, "optimizer_class"):
self.step()
self._reset_worker() # reset worker attributes for next batch
self._reset_worker() # reset worker attributes for next batch
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):

View File

@@ -2,7 +2,6 @@ import threading
from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
@@ -15,7 +14,6 @@ from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineB
class FillDrainWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
num_microbatches = self.num_microbatches
@@ -33,29 +31,40 @@ class FillDrainWorker(WorkerBase):
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
assert (
num_microbatches % stage_num == 0
), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
use_1F1B = False
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
super().__init__(
FillDrainWorker,
partition_fn,
stage_num,
num_microbatches,
device,
use_1F1B,
chunk,
criterion,
metric,
checkpoint,
data_process_func,
)
class OneFOneBWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
pp_rank = self.pp_rank
@@ -77,8 +86,7 @@ class OneFOneBWorker(WorkerBase):
# change outstanding_range at:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
target_key.phase == Phase.FORWARD:
if not is_last_stage and target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
# Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
outstanding_min = actual_stage_num - pp_rank - 1
@@ -91,30 +99,41 @@ class OneFOneBWorker(WorkerBase):
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
assert (
num_microbatches % stage_num == 0
), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
use_1F1B = True
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
super().__init__(
OneFOneBWorker,
partition_fn,
stage_num,
num_microbatches,
device,
use_1F1B,
chunk,
criterion,
metric,
checkpoint,
data_process_func,
)
class ChimeraWorker(WorkerBase):
def _get_producer_consumer(self) -> None:
rank = self.pp_rank
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
@@ -143,11 +162,12 @@ class ChimeraWorker(WorkerBase):
forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
forward_block_num = self.forward_times // forward_block_size
if self.forward_times >= real_microbatch_num or \
((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
if self.forward_times >= real_microbatch_num or (
(pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times
):
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else: # others
else: # others
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
@@ -168,7 +188,7 @@ class ChimeraWorker(WorkerBase):
# from corresponding up stage
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
device = self.device
self.device
if pp_rank < stage_num:
super()._initialize_partition()
else:
@@ -242,27 +262,38 @@ class ChimeraWorker(WorkerBase):
class ChimeraPipelineEngine(PipelineEngineBase):
def __init__(self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
assert num_microbatches % stage_num == 0, \
"In Chimera, num_microbatches must be the multiply of stage_num!"
def __init__(
self,
partition_fn: Callable,
stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!"
use_1F1B = False
chunk = 1
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
super().__init__(
ChimeraWorker,
partition_fn,
stage_num,
num_microbatches,
device,
use_1F1B,
chunk,
criterion,
metric,
checkpoint,
data_process_func,
)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
def _consume_constraint(
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
):
pass
def _create_pp_rank_to_rpc_worker_id(self) -> None:

View File

@@ -1,7 +1,7 @@
import argparse
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Tuple, Type, Union
import torch
import torch.distributed.rpc as rpc
@@ -61,7 +61,7 @@ def get_batch_lengths(batch):
def split_batch(batch: Any, start, stop, device: str):
if device == 'cuda':
if device == "cuda":
fn = lambda x: x[start:stop].cuda()
else:
fn = lambda x: x[start:stop]
@@ -102,8 +102,8 @@ def get_real_args_kwargs(args_or_kwargs):
def run_worker(rank, args, master_func):
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port
os.environ["MASTER_ADDR"] = args.master_addr
os.environ["MASTER_PORT"] = args.master_port
device = args.device
world_size = args.world_size
@@ -112,15 +112,17 @@ def run_worker(rank, args, master_func):
num_worker_threads = args.num_worker_threads
host = args.master_addr
port = args.master_port
backend = 'nccl' if device == 'cuda' else 'gloo'
backend = "nccl" if device == "cuda" else "gloo"
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(rank=rank,
world_size=world_size,
dp_degree=dp_degree,
tp_degree=tp_degree,
num_worker_threads=num_worker_threads,
device=device)
ppg.set_global_info(
rank=rank,
world_size=world_size,
dp_degree=dp_degree,
tp_degree=tp_degree,
num_worker_threads=num_worker_threads,
device=device,
)
ppg.args = args
# in rpc mode, only rank 0 is needed to be coded
if rank == 0:
@@ -139,17 +141,17 @@ def rpc_run(args, master_func):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--dp_degree', type=int, default=1)
parser.add_argument('--tp_degree', type=int, default=1)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=int, default=128)
parser.add_argument("--epoch", type=int, default=1)
parser.add_argument("--world_size", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--dp_degree", type=int, default=1)
parser.add_argument("--tp_degree", type=int, default=1)
parser.add_argument("--num_microbatches", type=int, default=2)
parser.add_argument("--chunk", type=int, default=1)
parser.add_argument("--use_checkpoint", action="store_true")
parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD")
parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda")
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=str, default="29020")
parser.add_argument("--num_worker_threads", type=int, default=128)
return parser.parse_args()

View File

@@ -38,8 +38,7 @@ def _binary_partition(weights: List, start: int, end: int):
def _heap_addition(weights: List, intervals: int, add_cnt: int):
"""
"""
""" """
def _heap_push(heap, st, ed):
value = weights[ed - 1]
@@ -113,8 +112,9 @@ def _binary_search(weights, num):
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
assert (
num_items % num_chunks == 0
), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
@@ -162,7 +162,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
elif isinstance(input_tensor, torch.Tensor):
kwargs_offset = 1
elif isinstance(input_tensor, (tuple, OrderedDict)):
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
# assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
# Huggingface will take their own structures based on OrderedDict as the output
# between layers so we've to close this check.
kwargs_offset = len(input_tensor)
@@ -204,21 +204,21 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
kwargs[k] = rst
return input_tensor
if isinstance(input_tensor, tuple):
assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None."
sig = inspect.signature(func)
func_args_num = len(sig.parameters)
assert func_args_num <= len(
input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
input_tensor
), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}."
if func_args_num < len(input_tensor):
return func(*input_tensor[:func_args_num])
else:
return func(*input_tensor)
assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple."
return func(input_tensor)
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
assert func_key in func_dict, f"{func_key} is not in the function_dict."
funcs_to_exec = func_dict[func_key]
if isinstance(funcs_to_exec, list):
@@ -243,7 +243,7 @@ def call_module(module, args=None, kwargs=None):
forward_func = module.forward
sig = inspect.signature(forward_func)
param_nums = len(sig.parameters)
feed_nums = len(args) + len(kwargs)
len(args) + len(kwargs)
args_needed_nums = param_nums - len(kwargs)
args_needed = args[:args_needed_nums]
if isinstance(module, CheckpointModule):
@@ -256,17 +256,17 @@ def call_module(module, args=None, kwargs=None):
def customized_partition(exec_seq):
'''
"""
This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
annotation to note the partition point.
'''
"""
customized_parts = {}
start = 0
stop = 0
rank = 0
for element in exec_seq:
if isinstance(element, str):
if element == 'SPLIT_NODE':
if element == "SPLIT_NODE":
customized_parts[rank] = [(start, stop)]
start = stop
rank += 1