mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .layer_spec import LayerSpec
|
||||
from .pipelinable import PipelinableContext, PipelinableModel
|
||||
|
||||
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
|
||||
__all__ = ["PipelinableModel", "PipelinableContext", "LayerSpec"]
|
||||
|
@@ -4,9 +4,7 @@ from colossalai.utils.model.utils import call_to_str
|
||||
|
||||
|
||||
class LayerSpec:
|
||||
"""
|
||||
|
||||
"""
|
||||
""" """
|
||||
|
||||
def __init__(self, typename, *module_args, **module_kwargs):
|
||||
self.typename = typename
|
||||
@@ -16,7 +14,7 @@ class LayerSpec:
|
||||
self._param_count = 0
|
||||
|
||||
if not issubclass(typename, torch.nn.Module):
|
||||
raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
|
||||
raise RuntimeError("LayerSpec only supports torch.nn.Module types.")
|
||||
|
||||
def __repr__(self):
|
||||
return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
|
||||
|
@@ -1,3 +1,3 @@
|
||||
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
|
||||
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
|
||||
__all__ = ["Topo", "Partition", "PartitionOutputVal", "PartitionInputVal"]
|
||||
|
@@ -1,3 +1,3 @@
|
||||
from .fx import get_topology as get_fx_topology
|
||||
|
||||
__all__ = ['get_fx_topology']
|
||||
__all__ = ["get_fx_topology"]
|
||||
|
@@ -10,7 +10,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
|
||||
elif is_output:
|
||||
partition_id = 1
|
||||
else:
|
||||
prefix = 'submod_'
|
||||
prefix = "submod_"
|
||||
partition_id = int(partition_name.split(prefix)[-1]) + 2
|
||||
return partition_id
|
||||
|
||||
@@ -27,10 +27,10 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
|
||||
|
||||
def find_input_in_partition(node, partitions, input_partitions=None):
|
||||
p_input_val = None
|
||||
direct_def = not node.name.startswith('getitem')
|
||||
direct_def = not node.name.startswith("getitem")
|
||||
# search in input
|
||||
if direct_def and input_partitions is not None:
|
||||
partition_id = partition_name_to_id('', is_input=True)
|
||||
partition_id = partition_name_to_id("", is_input=True)
|
||||
for i, input_node in enumerate(input_partitions):
|
||||
if input_node == node:
|
||||
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
|
||||
@@ -57,7 +57,7 @@ def find_input_in_partition(node, partitions, input_partitions=None):
|
||||
def find_output_in_partition(node, partitions, output_partitions=None):
|
||||
p_output_val = PartitionOutputVal()
|
||||
for user in node.users:
|
||||
direct_use = not user.name.startswith('getitem')
|
||||
direct_use = not user.name.startswith("getitem")
|
||||
# user is mid partition
|
||||
for partition in partitions:
|
||||
# direct call
|
||||
@@ -82,7 +82,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
|
||||
output_node = output_partitions[0]
|
||||
if user.op == output_node.op:
|
||||
output_keys = {}
|
||||
partition_id = partition_name_to_id('', is_output=True)
|
||||
partition_id = partition_name_to_id("", is_output=True)
|
||||
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
|
||||
for i, arg in enumerate(output_keys):
|
||||
if arg == node:
|
||||
@@ -99,11 +99,11 @@ def get_topology(gm: GraphModule):
|
||||
partitions = []
|
||||
output_partitions = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
input_partitions.append(node)
|
||||
elif node.name.startswith('submod_'):
|
||||
elif node.name.startswith("submod_"):
|
||||
partitions.append(node)
|
||||
elif node.op == 'output':
|
||||
elif node.op == "output":
|
||||
output_partitions.append(node)
|
||||
else:
|
||||
continue
|
||||
@@ -127,7 +127,7 @@ def get_topology(gm: GraphModule):
|
||||
# set output for submodule
|
||||
direct_use = True
|
||||
for user in partition.users:
|
||||
if user.name.startswith('getitem'):
|
||||
if user.name.startswith("getitem"):
|
||||
direct_use = False
|
||||
break
|
||||
if direct_use:
|
||||
@@ -146,7 +146,8 @@ def get_topology(gm: GraphModule):
|
||||
topo_output_partition = Partition()
|
||||
torch.fx.graph.map_arg(
|
||||
partition.args[0],
|
||||
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
|
||||
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)),
|
||||
)
|
||||
topo.set_partitions(partition_id=1, partition=topo_output_partition)
|
||||
topo.set_output_partition_id(partition_id=1)
|
||||
|
||||
|
@@ -10,7 +10,7 @@ class ValPosition:
|
||||
offset: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
|
||||
res = f"[partition_id:{self.partition_id},offset:{self.offset}]"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -18,7 +18,6 @@ class ValPosition:
|
||||
|
||||
|
||||
class PartitionInputVal(object):
|
||||
|
||||
def __init__(self, partition_id, offset) -> None:
|
||||
# every input from which partition_id and which offset
|
||||
val_pos = ValPosition(partition_id, offset)
|
||||
@@ -28,8 +27,8 @@ class PartitionInputVal(object):
|
||||
return self._from_partition_and_offset
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += f'<-({self._from_partition_and_offset})'
|
||||
res = ""
|
||||
res += f"<-({self._from_partition_and_offset})"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -37,7 +36,6 @@ class PartitionInputVal(object):
|
||||
|
||||
|
||||
class PartitionOutputVal(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
# every output to which partition_id and which offset
|
||||
self._to_partition_and_offset: List[ValPosition] = []
|
||||
@@ -50,11 +48,11 @@ class PartitionOutputVal(object):
|
||||
return self._to_partition_and_offset
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += '->('
|
||||
res = ""
|
||||
res += "->("
|
||||
for val_pos in self._to_partition_and_offset:
|
||||
res += f'{val_pos},'
|
||||
res += ')'
|
||||
res += f"{val_pos},"
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -62,7 +60,6 @@ class PartitionOutputVal(object):
|
||||
|
||||
|
||||
class Partition(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_vals: List[PartitionInputVal] = []
|
||||
self._output_vals: List[PartitionOutputVal] = []
|
||||
@@ -110,16 +107,16 @@ class Partition(object):
|
||||
return res
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += f' input:\n'
|
||||
res += f' length:{len(self._input_vals)}\n'
|
||||
res = ""
|
||||
res += f" input:\n"
|
||||
res += f" length:{len(self._input_vals)}\n"
|
||||
for i, input_val in enumerate(self._input_vals):
|
||||
res += f' offset={i}:{input_val}\n'
|
||||
res += f" offset={i}:{input_val}\n"
|
||||
|
||||
res += f' output:\n'
|
||||
res += f' length:{len(self._output_vals)}\n'
|
||||
res += f" output:\n"
|
||||
res += f" length:{len(self._output_vals)}\n"
|
||||
for i, output_val in enumerate(self._output_vals):
|
||||
res += f' offset={i}:{output_val}\n'
|
||||
res += f" offset={i}:{output_val}\n"
|
||||
|
||||
return res
|
||||
|
||||
@@ -140,7 +137,6 @@ class Partition(object):
|
||||
# _input_partition_id: the key represents input_partition
|
||||
# _output_partition_id: the key represents output_partition
|
||||
class Topo(object):
|
||||
|
||||
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
|
||||
self._partitions: Dict[int, Partition] = {}
|
||||
self._input_partition_id = input_partition_id
|
||||
@@ -162,7 +158,7 @@ class Topo(object):
|
||||
self._partitions[partition_id] = partition
|
||||
|
||||
def get_mid_partitions(self):
|
||||
res = {} #{partition_id: Partition}
|
||||
res = {} # {partition_id: Partition}
|
||||
for partition_id, partition in self._partitions.items():
|
||||
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
|
||||
continue
|
||||
@@ -186,27 +182,27 @@ class Topo(object):
|
||||
return self._partitions[partition_id]
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res = ""
|
||||
if len(self._partitions) == 0:
|
||||
return 'Empty Topo Graph.'
|
||||
return "Empty Topo Graph."
|
||||
|
||||
input_part = self.get_input_partition()
|
||||
if input_part is not None:
|
||||
res += '{\n'
|
||||
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}"
|
||||
res += "}\n"
|
||||
|
||||
mid_parts = self.get_mid_partitions()
|
||||
for i, (partition_id, part) in enumerate(mid_parts.items()):
|
||||
res += '{\n'
|
||||
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}"
|
||||
res += "}\n"
|
||||
|
||||
output_part = self.get_output_partition()
|
||||
if output_part is not None:
|
||||
res += '{\n'
|
||||
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}"
|
||||
res += "}\n"
|
||||
|
||||
return res
|
||||
|
||||
|
@@ -132,8 +132,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
for child in self._root_children:
|
||||
layer_spec = self._layer_spec_dict[id(child)]
|
||||
if layer_spec.typename in (
|
||||
torch.nn.modules.container.ModuleList,
|
||||
torch.nn.modules.container.Sequential,
|
||||
torch.nn.modules.container.ModuleList,
|
||||
torch.nn.modules.container.Sequential,
|
||||
):
|
||||
for child_in_container in layer_spec.children:
|
||||
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
|
||||
@@ -198,8 +198,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
param_counts.append(layer_spec.count_params())
|
||||
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
|
||||
elif self._policy == "customized":
|
||||
assert (self._exec_seq
|
||||
is not None), f"An explicit exec_seq must be defined by user in customized policy mode."
|
||||
assert (
|
||||
self._exec_seq is not None
|
||||
), f"An explicit exec_seq must be defined by user in customized policy mode."
|
||||
self.customized_parts = customized_partition(self._exec_seq)
|
||||
assert len(self.customized_parts) == gpc.get_world_size(
|
||||
ParallelMode.PIPELINE
|
||||
@@ -226,14 +227,14 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
|
||||
elif (layer, "behind") in self._func_dict:
|
||||
behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]
|
||||
module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
|
||||
pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
|
||||
behind_func_dict_in_partition)
|
||||
pipeline_model = PipelinableModel(
|
||||
module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition
|
||||
)
|
||||
|
||||
return pipeline_model
|
||||
|
||||
|
||||
class PipelinableModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, module_list, front_func_dict, behind_func_dict):
|
||||
super().__init__()
|
||||
self._module_list = module_list
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import List
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import rpc
|
||||
@@ -14,14 +13,15 @@ class PipelineProcessGroup:
|
||||
def __init__(self) -> None:
|
||||
self.is_initialize = False
|
||||
|
||||
def set_global_info(self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
dp_degree: int = 1,
|
||||
tp_degree: int = 1,
|
||||
num_worker_threads: int = 1,
|
||||
device: str = "cuda") -> None:
|
||||
|
||||
def set_global_info(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
dp_degree: int = 1,
|
||||
tp_degree: int = 1,
|
||||
num_worker_threads: int = 1,
|
||||
device: str = "cuda",
|
||||
) -> None:
|
||||
device_mesh_size = dp_degree * tp_degree
|
||||
assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
|
||||
self._num_worker_threads = num_worker_threads
|
||||
@@ -60,8 +60,8 @@ class PipelineProcessGroup:
|
||||
device = self.device
|
||||
world_size = self.get_world_size()
|
||||
rank = self.get_global_rank()
|
||||
backend = 'nccl' if device == 'cuda' else 'gloo'
|
||||
dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')
|
||||
backend = "nccl" if device == "cuda" else "gloo"
|
||||
dist.init_process_group(backend, world_size=world_size, rank=rank, group_name="main_group")
|
||||
|
||||
def _initialize_pp_process_group(self) -> None:
|
||||
rank = self.get_global_rank()
|
||||
@@ -71,9 +71,9 @@ class PipelineProcessGroup:
|
||||
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)
|
||||
|
||||
for pp_rank in self._pp_ranks:
|
||||
options.set_device_map(f'work{pp_rank}', {rank: pp_rank})
|
||||
options.set_device_map(f"work{pp_rank}", {rank: pp_rank})
|
||||
|
||||
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
|
||||
rpc.init_rpc(name=f"work{rank}", rank=rank, world_size=world_size, rpc_backend_options=options)
|
||||
|
||||
def _initialize_tp_dp_process_group(self) -> None:
|
||||
rank = self.get_global_rank()
|
||||
@@ -147,10 +147,10 @@ class PipelineProcessGroup:
|
||||
|
||||
def get_chimera_all_reduce_group(self, pp_rank: int):
|
||||
with self.chimera_lock:
|
||||
if not hasattr(self, 'chimera_groups'):
|
||||
if not hasattr(self, "chimera_groups"):
|
||||
world_size = self.get_world_size()
|
||||
stage_num = self.get_stage_num()
|
||||
assert world_size % 2 == 0, 'world_size must be even in chimera!'
|
||||
assert world_size % 2 == 0, "world_size must be even in chimera!"
|
||||
self.chimera_groups = {}
|
||||
for rank in range(world_size // 2):
|
||||
pair = [rank, world_size - 1 - rank]
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from .utils import pytree_map
|
||||
|
||||
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
|
||||
__all__ = ["FillDrainPipelineEngine", "OneFOneBPipelineEngine", "ChimeraPipelineEngine", "pytree_map"]
|
||||
|
@@ -12,17 +12,9 @@ from torch import autograd, nn, optim
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from colossalai.legacy.pipeline.middleware import Partition, Topo
|
||||
from colossalai.legacy.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.legacy.pipeline.rpc.utils import (
|
||||
get_batch_lengths,
|
||||
pyobj_map,
|
||||
pytree_filter,
|
||||
pytree_map,
|
||||
split_batch,
|
||||
tensor_shape_list,
|
||||
type_detail,
|
||||
)
|
||||
from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
@@ -33,7 +25,7 @@ class Phase(Enum):
|
||||
|
||||
|
||||
class UniqueKey:
|
||||
__slots__ = ('microbatch_id', 'phase')
|
||||
__slots__ = ("microbatch_id", "phase")
|
||||
microbatch_id: int
|
||||
phase: Phase
|
||||
|
||||
@@ -48,12 +40,22 @@ class UniqueKey:
|
||||
return tuple.__hash__((self.microbatch_id, self.phase))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})'
|
||||
return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})"
|
||||
|
||||
|
||||
class WorkItem:
|
||||
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
|
||||
'num_microbatches', 'forward_only')
|
||||
__slots__ = (
|
||||
"stage_id",
|
||||
"phase",
|
||||
"args",
|
||||
"kwargs",
|
||||
"output",
|
||||
"refcount",
|
||||
"microbatch_id",
|
||||
"batch_id",
|
||||
"num_microbatches",
|
||||
"forward_only",
|
||||
)
|
||||
|
||||
stage_id: int
|
||||
phase: Phase
|
||||
@@ -66,50 +68,45 @@ class WorkItem:
|
||||
num_microbatches: int
|
||||
forward_only: bool
|
||||
|
||||
def __init__(self,
|
||||
stage_id,
|
||||
phase,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
microbatch_id,
|
||||
batch_id,
|
||||
num_microbatches,
|
||||
forward_only,
|
||||
refcount=0) -> None:
|
||||
def __init__(
|
||||
self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0
|
||||
) -> None:
|
||||
for attr_name in self.__slots__:
|
||||
setattr(self, attr_name, locals()[attr_name])
|
||||
|
||||
|
||||
class BackwardCache:
|
||||
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
|
||||
__slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs")
|
||||
checkpoint: bool
|
||||
stage_input_args: Tuple[Any]
|
||||
stage_input_kwargs: Dict[Any, Any]
|
||||
stage_outputs: Tuple[Any]
|
||||
|
||||
def __init__(self,
|
||||
stage_input_args: Tuple[Any],
|
||||
stage_input_kwargs: Dict[Any, Any] = None,
|
||||
stage_outputs: Tuple[Any] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
stage_input_args: Tuple[Any],
|
||||
stage_input_kwargs: Dict[Any, Any] = None,
|
||||
stage_outputs: Tuple[Any] = None,
|
||||
checkpoint: bool = False,
|
||||
) -> None:
|
||||
for arg_name in self.__slots__:
|
||||
setattr(self, arg_name, locals()[arg_name])
|
||||
|
||||
|
||||
class WorkerBase(ABC):
|
||||
|
||||
def __init__(self,
|
||||
partition_fn: Callable,
|
||||
partition_args: tuple,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
partition_fn: Callable,
|
||||
partition_args: tuple,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pp_rank = pp_rank
|
||||
@@ -150,11 +147,11 @@ class WorkerBase(ABC):
|
||||
self._initialize_context_container()
|
||||
|
||||
# main loop
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True)
|
||||
self.main_loop_thread.start()
|
||||
|
||||
def _get_future_by_device(self):
|
||||
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
|
||||
return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device])
|
||||
|
||||
def _initialize_outstanding_range(self):
|
||||
outstanding_range = None
|
||||
@@ -199,12 +196,13 @@ class WorkerBase(ABC):
|
||||
# lifecycle management for DAG scheduler
|
||||
if output_work_item.phase == Phase.FORWARD:
|
||||
lifecycle = len(self.get_consumer_stage_ids())
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
lifecycle += 1
|
||||
elif output_work_item.phase == Phase.BACKWARD:
|
||||
lifecycle = len(self.get_producer_stage_ids())
|
||||
if self.is_model_input() and self._is_last_step(
|
||||
output_work_item): # an extra reference for ensure_backward
|
||||
output_work_item
|
||||
): # an extra reference for ensure_backward
|
||||
lifecycle += 1
|
||||
else:
|
||||
lifecycle = 0
|
||||
@@ -234,9 +232,9 @@ class WorkerBase(ABC):
|
||||
# offset supports get partial output to reduce comm costs.
|
||||
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
|
||||
output = self._get_output_all(key, ref_use, rank)
|
||||
if offsets is None: # get all for non iterable output
|
||||
if offsets is None: # get all for non iterable output
|
||||
return output
|
||||
else: # get part for iterable output
|
||||
else: # get part for iterable output
|
||||
output = [output[i] for i in offsets]
|
||||
return output
|
||||
|
||||
@@ -252,12 +250,12 @@ class WorkerBase(ABC):
|
||||
|
||||
def get_partition(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
return self.module_partition
|
||||
|
||||
def get_partition_state_dict(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
return self.module_partition.state_dict()
|
||||
|
||||
def _make_args_kwargs(self, microbatch, merge=False):
|
||||
@@ -293,8 +291,17 @@ class WorkerBase(ABC):
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.FORWARD,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
@@ -314,15 +321,25 @@ class WorkerBase(ABC):
|
||||
for off in self_input_offsets:
|
||||
self_arg_lst.append(arg_lst[off])
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.FORWARD,
|
||||
self_arg_lst,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# put input tensor which other nodes need into output_list as Phase.INPUT
|
||||
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item_remote = WorkItem(
|
||||
self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only
|
||||
)
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list[recv_input_key] = work_item_remote
|
||||
@@ -343,8 +360,17 @@ class WorkerBase(ABC):
|
||||
output = self._get_future_by_device()
|
||||
grad_wrt_loss = None
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, False)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.BACKWARD,
|
||||
grad_wrt_loss,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
False,
|
||||
)
|
||||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
@@ -367,7 +393,7 @@ class WorkerBase(ABC):
|
||||
producer_stage_ids = self.get_producer_stage_ids()
|
||||
producer_num = len(producer_stage_ids)
|
||||
if self.need_model_input():
|
||||
producer_num += 1 # for input partition
|
||||
producer_num += 1 # for input partition
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
@@ -376,9 +402,9 @@ class WorkerBase(ABC):
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
offsets = self._get_input_offsets_by_index(target_index=0)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
|
||||
rank=self.pp_rank,
|
||||
offsets=offsets)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
for i in range(0, producer_num - 1):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
@@ -386,11 +412,12 @@ class WorkerBase(ABC):
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
target_index = i + 1
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
@@ -399,14 +426,24 @@ class WorkerBase(ABC):
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
target_index = i
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
work_item_from_producer = WorkItem(
|
||||
stage_id,
|
||||
Phase.FORWARD,
|
||||
subscribe_forward_futures,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
|
||||
return work_item_from_producer
|
||||
|
||||
@@ -441,15 +478,25 @@ class WorkerBase(ABC):
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||
target_index = i
|
||||
offsets = self._get_output_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_backward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
|
||||
consumer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
consumer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
# flatten args
|
||||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, False)
|
||||
work_item_from_consumer = WorkItem(
|
||||
stage_id,
|
||||
Phase.BACKWARD,
|
||||
subscribe_backward_futures,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
False,
|
||||
)
|
||||
|
||||
return work_item_from_consumer
|
||||
|
||||
@@ -524,8 +571,8 @@ class WorkerBase(ABC):
|
||||
|
||||
def get_topo(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
if hasattr(self.module_partition, '_topo'):
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
if hasattr(self.module_partition, "_topo"):
|
||||
return self.module_partition._topo
|
||||
else:
|
||||
return None
|
||||
@@ -564,12 +611,12 @@ class WorkerBase(ABC):
|
||||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if target_index == src_index:
|
||||
if output_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(src_offset)
|
||||
@@ -584,7 +631,6 @@ class WorkerBase(ABC):
|
||||
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||
for val_list in output_vals:
|
||||
# An output may be passed to many down stages.
|
||||
target = None
|
||||
for val_pos in val_list.get():
|
||||
dst_partition_id = val_pos.partition_id
|
||||
dst_offset = val_pos.offset
|
||||
@@ -597,7 +643,7 @@ class WorkerBase(ABC):
|
||||
break
|
||||
if target_index == dst_index:
|
||||
if input_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(dst_offset)
|
||||
@@ -623,7 +669,7 @@ class WorkerBase(ABC):
|
||||
flatten_args = []
|
||||
if self.is_first_stage():
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
else: # get by offset
|
||||
else: # get by offset
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
@@ -652,7 +698,7 @@ class WorkerBase(ABC):
|
||||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if output_len == 1:
|
||||
@@ -679,7 +725,7 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
for i, arg in enumerate(args_or_kwargs):
|
||||
args_or_kwargs[i] = arg.wait()
|
||||
if args_or_kwargs is not None: # get by offset
|
||||
if args_or_kwargs is not None: # get by offset
|
||||
flatten_args = []
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
@@ -719,7 +765,7 @@ class WorkerBase(ABC):
|
||||
@abstractmethod
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
"""
|
||||
this method control the order of the microbatch to consume
|
||||
this method control the order of the microbatch to consume
|
||||
"""
|
||||
|
||||
def is_first_stage(self):
|
||||
@@ -761,7 +807,7 @@ class WorkerBase(ABC):
|
||||
kwargs = work_item.kwargs
|
||||
microbatch_id = work_item.microbatch_id
|
||||
forward_only = work_item.forward_only
|
||||
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
|
||||
data_process_func = getattr(self, "data_process_func", self._default_data_process_func)
|
||||
consume_result = None
|
||||
|
||||
is_first_stage = self.is_first_stage()
|
||||
@@ -787,10 +833,12 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
|
||||
process_types=torch.device) # change devices from last stage to current device
|
||||
args_kwargs = pyobj_map(
|
||||
args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(
|
||||
args_kwargs, fn=lambda x: self.device, process_types=torch.device
|
||||
) # change devices from last stage to current device
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
@@ -851,16 +899,16 @@ class WorkerBase(ABC):
|
||||
use_checkpoint = False
|
||||
|
||||
if not forward_only:
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
|
||||
stage_input_kwargs,
|
||||
stage_outputs,
|
||||
checkpoint=use_checkpoint)
|
||||
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(
|
||||
stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint
|
||||
)
|
||||
consume_result = pyobj_map(
|
||||
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in
|
||||
|
||||
# if not forward_only, do the backward
|
||||
if not forward_only:
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
self._begin_backward(microbatch_id)
|
||||
|
||||
elif phase == Phase.BACKWARD:
|
||||
@@ -872,7 +920,9 @@ class WorkerBase(ABC):
|
||||
self.backward_times += 1
|
||||
self.outstanding -= 1
|
||||
|
||||
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
|
||||
assert (
|
||||
microbatch_id in self.microbatch_id_to_backward_cache
|
||||
), f"microbatch_id {microbatch_id} not in backward cache"
|
||||
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
|
||||
|
||||
stage_outputs = backward_cache.stage_outputs
|
||||
@@ -906,8 +956,9 @@ class WorkerBase(ABC):
|
||||
filtered_grads.append(grad)
|
||||
|
||||
stage_outputs = filtered_outputs
|
||||
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
grad_tensors = pyobj_map(
|
||||
filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
@@ -920,8 +971,8 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
consume_result.append(None)
|
||||
consume_result = pyobj_map(
|
||||
consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
@@ -929,7 +980,7 @@ class WorkerBase(ABC):
|
||||
return consume_result
|
||||
|
||||
def _get_store_len(self):
|
||||
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
|
||||
return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}"
|
||||
|
||||
def _get_parameter_grad_sum(self):
|
||||
grad_sum = 0
|
||||
@@ -1014,19 +1065,20 @@ class WorkerBase(ABC):
|
||||
|
||||
|
||||
class PipelineEngineBase(ABC, nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
worker_type,
|
||||
partition_fn: Callable,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
use_1F1B=False,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
worker_type,
|
||||
partition_fn: Callable,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
use_1F1B=False,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.worker_type = worker_type
|
||||
self.partition_fn: Callable = partition_fn
|
||||
@@ -1056,12 +1108,12 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
data_process_func = self.data_process_func
|
||||
if data_process_func is not None:
|
||||
assert callable(data_process_func), "data_process_func must be a function"
|
||||
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
|
||||
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
|
||||
assert "<locals>" not in data_process_func.__repr__(), "data_process_func must be a global function"
|
||||
assert "<lambda>" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(
|
||||
sig.parameters
|
||||
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
|
||||
assert (
|
||||
len(sig.parameters) == 2
|
||||
), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
|
||||
|
||||
def _get_actual_stage_num(self) -> int:
|
||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||
@@ -1104,19 +1156,33 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
||||
partition_args = (partition_id, chunk, actual_stage_num)
|
||||
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
||||
if device[:4] == 'cuda':
|
||||
device = f'cuda:{rpc_worker_id}'
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||
worker_type,
|
||||
args=(partition_fn, partition_args, pp_rank,
|
||||
actual_stage_num, num_microbatches, device,
|
||||
criterion, metric, checkpoint, data_process_func))
|
||||
if device[:4] == "cuda":
|
||||
device = f"cuda:{rpc_worker_id}"
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(
|
||||
rpc_worker_id,
|
||||
worker_type,
|
||||
args=(
|
||||
partition_fn,
|
||||
partition_args,
|
||||
pp_rank,
|
||||
actual_stage_num,
|
||||
num_microbatches,
|
||||
device,
|
||||
criterion,
|
||||
metric,
|
||||
checkpoint,
|
||||
data_process_func,
|
||||
),
|
||||
)
|
||||
|
||||
# let each worker know global worker rref (include itself)
|
||||
sync_futs = []
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
|
||||
self.pp_rank_to_worker_rref)
|
||||
fut = (
|
||||
self.pp_rank_to_worker_rref[pp_rank]
|
||||
.rpc_async(timeout=0)
|
||||
.sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||
)
|
||||
sync_futs.append(fut)
|
||||
|
||||
for fut in sync_futs:
|
||||
@@ -1157,8 +1223,9 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
def get_output_pp_ranks(self) -> List[int]:
|
||||
return [self._get_actual_stage_num() - 1]
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
|
||||
output_pp_ranks: List[int], ret_future):
|
||||
def _consume_constraint(
|
||||
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
|
||||
):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
use_1F1B = self.use_1F1B
|
||||
if microbatch_id >= actual_stage_num:
|
||||
@@ -1206,7 +1273,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
fut = worker_rref.rpc_async().get_output_by_key(
|
||||
key, offsets=[]) # only ensure the res exists, no need for real data.
|
||||
key, offsets=[]
|
||||
) # only ensure the res exists, no need for real data.
|
||||
backward_result.append(fut)
|
||||
|
||||
for fut in backward_result:
|
||||
@@ -1244,11 +1312,14 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
|
||||
if labels is not None and not forward_only:
|
||||
assert hasattr(
|
||||
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||
self, "optimizer_class"
|
||||
), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||
|
||||
num_microbatches = self.num_microbatches
|
||||
|
||||
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
|
||||
assert (
|
||||
batch_length >= num_microbatches
|
||||
), "num_microbatches is greater than the size of a batch, which is illegal"
|
||||
microbatch_size = math.ceil(batch_length / num_microbatches)
|
||||
device = self.device
|
||||
|
||||
@@ -1285,10 +1356,10 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
# collect forward result
|
||||
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
||||
|
||||
if not forward_only and hasattr(self, 'optimizer_class'):
|
||||
if not forward_only and hasattr(self, "optimizer_class"):
|
||||
self.step()
|
||||
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
return forward_result
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
|
@@ -2,7 +2,6 @@ import threading
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
@@ -15,7 +14,6 @@ from colossalai.legacy.pipeline.rpc._pipeline_base import Phase, PipelineEngineB
|
||||
|
||||
|
||||
class FillDrainWorker(WorkerBase):
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
# execute backward first (if backward phase in work_list)
|
||||
num_microbatches = self.num_microbatches
|
||||
@@ -33,29 +31,40 @@ class FillDrainWorker(WorkerBase):
|
||||
|
||||
|
||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
assert (
|
||||
num_microbatches % stage_num == 0
|
||||
), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
use_1F1B = False
|
||||
|
||||
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint, data_process_func)
|
||||
super().__init__(
|
||||
FillDrainWorker,
|
||||
partition_fn,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device,
|
||||
use_1F1B,
|
||||
chunk,
|
||||
criterion,
|
||||
metric,
|
||||
checkpoint,
|
||||
data_process_func,
|
||||
)
|
||||
|
||||
|
||||
class OneFOneBWorker(WorkerBase):
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
# execute backward first (if backward phase in work_list)
|
||||
pp_rank = self.pp_rank
|
||||
@@ -77,8 +86,7 @@ class OneFOneBWorker(WorkerBase):
|
||||
# change outstanding_range at:
|
||||
# 1. forward times reach actual_stage_num, this is the end of continuous forward
|
||||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
||||
if not is_last_stage and \
|
||||
target_key.phase == Phase.FORWARD:
|
||||
if not is_last_stage and target_key.phase == Phase.FORWARD:
|
||||
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
|
||||
# Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2
|
||||
outstanding_min = actual_stage_num - pp_rank - 1
|
||||
@@ -91,30 +99,41 @@ class OneFOneBWorker(WorkerBase):
|
||||
|
||||
|
||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
assert (
|
||||
num_microbatches % stage_num == 0
|
||||
), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
# assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk"
|
||||
use_1F1B = True
|
||||
|
||||
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint, data_process_func)
|
||||
super().__init__(
|
||||
OneFOneBWorker,
|
||||
partition_fn,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device,
|
||||
use_1F1B,
|
||||
chunk,
|
||||
criterion,
|
||||
metric,
|
||||
checkpoint,
|
||||
data_process_func,
|
||||
)
|
||||
|
||||
|
||||
class ChimeraWorker(WorkerBase):
|
||||
|
||||
def _get_producer_consumer(self) -> None:
|
||||
rank = self.pp_rank
|
||||
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
|
||||
@@ -143,11 +162,12 @@ class ChimeraWorker(WorkerBase):
|
||||
forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
|
||||
forward_block_num = self.forward_times // forward_block_size
|
||||
|
||||
if self.forward_times >= real_microbatch_num or \
|
||||
((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
|
||||
if self.forward_times >= real_microbatch_num or (
|
||||
(pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times
|
||||
):
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
else: # others
|
||||
else: # others
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
|
||||
@@ -168,7 +188,7 @@ class ChimeraWorker(WorkerBase):
|
||||
# from corresponding up stage
|
||||
pp_rank = self.pp_rank
|
||||
stage_num = self.actual_stage_num
|
||||
device = self.device
|
||||
self.device
|
||||
if pp_rank < stage_num:
|
||||
super()._initialize_partition()
|
||||
else:
|
||||
@@ -242,27 +262,38 @@ class ChimeraWorker(WorkerBase):
|
||||
|
||||
|
||||
class ChimeraPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"In Chimera, num_microbatches must be the multiply of stage_num!"
|
||||
def __init__(
|
||||
self,
|
||||
partition_fn: Callable,
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!"
|
||||
use_1F1B = False
|
||||
chunk = 1
|
||||
|
||||
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint, data_process_func)
|
||||
super().__init__(
|
||||
ChimeraWorker,
|
||||
partition_fn,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device,
|
||||
use_1F1B,
|
||||
chunk,
|
||||
criterion,
|
||||
metric,
|
||||
checkpoint,
|
||||
data_process_func,
|
||||
)
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
|
||||
output_pp_ranks: List[int], ret_future):
|
||||
def _consume_constraint(
|
||||
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
|
||||
):
|
||||
pass
|
||||
|
||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
||||
from typing import Any, Callable, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
@@ -61,7 +61,7 @@ def get_batch_lengths(batch):
|
||||
|
||||
|
||||
def split_batch(batch: Any, start, stop, device: str):
|
||||
if device == 'cuda':
|
||||
if device == "cuda":
|
||||
fn = lambda x: x[start:stop].cuda()
|
||||
else:
|
||||
fn = lambda x: x[start:stop]
|
||||
@@ -102,8 +102,8 @@ def get_real_args_kwargs(args_or_kwargs):
|
||||
|
||||
|
||||
def run_worker(rank, args, master_func):
|
||||
os.environ['MASTER_ADDR'] = args.master_addr
|
||||
os.environ['MASTER_PORT'] = args.master_port
|
||||
os.environ["MASTER_ADDR"] = args.master_addr
|
||||
os.environ["MASTER_PORT"] = args.master_port
|
||||
|
||||
device = args.device
|
||||
world_size = args.world_size
|
||||
@@ -112,15 +112,17 @@ def run_worker(rank, args, master_func):
|
||||
num_worker_threads = args.num_worker_threads
|
||||
host = args.master_addr
|
||||
port = args.master_port
|
||||
backend = 'nccl' if device == 'cuda' else 'gloo'
|
||||
backend = "nccl" if device == "cuda" else "gloo"
|
||||
|
||||
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
|
||||
ppg.set_global_info(rank=rank,
|
||||
world_size=world_size,
|
||||
dp_degree=dp_degree,
|
||||
tp_degree=tp_degree,
|
||||
num_worker_threads=num_worker_threads,
|
||||
device=device)
|
||||
ppg.set_global_info(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_degree=dp_degree,
|
||||
tp_degree=tp_degree,
|
||||
num_worker_threads=num_worker_threads,
|
||||
device=device,
|
||||
)
|
||||
ppg.args = args
|
||||
# in rpc mode, only rank 0 is needed to be coded
|
||||
if rank == 0:
|
||||
@@ -139,17 +141,17 @@ def rpc_run(args, master_func):
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--epoch', type=int, default=1)
|
||||
parser.add_argument('--world_size', type=int, default=2)
|
||||
parser.add_argument('--batch_size', type=int, default=16)
|
||||
parser.add_argument('--dp_degree', type=int, default=1)
|
||||
parser.add_argument('--tp_degree', type=int, default=1)
|
||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||
parser.add_argument('--chunk', type=int, default=1)
|
||||
parser.add_argument('--use_checkpoint', action='store_true')
|
||||
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
|
||||
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
|
||||
parser.add_argument('--master_addr', type=str, default='localhost')
|
||||
parser.add_argument('--master_port', type=str, default='29020')
|
||||
parser.add_argument('--num_worker_threads', type=int, default=128)
|
||||
parser.add_argument("--epoch", type=int, default=1)
|
||||
parser.add_argument("--world_size", type=int, default=2)
|
||||
parser.add_argument("--batch_size", type=int, default=16)
|
||||
parser.add_argument("--dp_degree", type=int, default=1)
|
||||
parser.add_argument("--tp_degree", type=int, default=1)
|
||||
parser.add_argument("--num_microbatches", type=int, default=2)
|
||||
parser.add_argument("--chunk", type=int, default=1)
|
||||
parser.add_argument("--use_checkpoint", action="store_true")
|
||||
parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD")
|
||||
parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda")
|
||||
parser.add_argument("--master_addr", type=str, default="localhost")
|
||||
parser.add_argument("--master_port", type=str, default="29020")
|
||||
parser.add_argument("--num_worker_threads", type=int, default=128)
|
||||
return parser.parse_args()
|
||||
|
@@ -38,8 +38,7 @@ def _binary_partition(weights: List, start: int, end: int):
|
||||
|
||||
|
||||
def _heap_addition(weights: List, intervals: int, add_cnt: int):
|
||||
"""
|
||||
"""
|
||||
""" """
|
||||
|
||||
def _heap_push(heap, st, ed):
|
||||
value = weights[ed - 1]
|
||||
@@ -113,8 +112,9 @@ def _binary_search(weights, num):
|
||||
|
||||
|
||||
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||
assert num_items % num_chunks == 0, \
|
||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
|
||||
assert (
|
||||
num_items % num_chunks == 0
|
||||
), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
|
||||
|
||||
logger = get_dist_logger()
|
||||
parts = [[] for _ in range(pipeline_parallel_size)]
|
||||
@@ -162,7 +162,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
|
||||
elif isinstance(input_tensor, torch.Tensor):
|
||||
kwargs_offset = 1
|
||||
elif isinstance(input_tensor, (tuple, OrderedDict)):
|
||||
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
|
||||
# assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
|
||||
# Huggingface will take their own structures based on OrderedDict as the output
|
||||
# between layers so we've to close this check.
|
||||
kwargs_offset = len(input_tensor)
|
||||
@@ -204,21 +204,21 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
|
||||
kwargs[k] = rst
|
||||
return input_tensor
|
||||
if isinstance(input_tensor, tuple):
|
||||
assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
|
||||
assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None."
|
||||
sig = inspect.signature(func)
|
||||
func_args_num = len(sig.parameters)
|
||||
assert func_args_num <= len(
|
||||
input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
|
||||
input_tensor
|
||||
), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}."
|
||||
if func_args_num < len(input_tensor):
|
||||
return func(*input_tensor[:func_args_num])
|
||||
else:
|
||||
return func(*input_tensor)
|
||||
assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
|
||||
assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple."
|
||||
return func(input_tensor)
|
||||
|
||||
|
||||
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
|
||||
|
||||
assert func_key in func_dict, f"{func_key} is not in the function_dict."
|
||||
funcs_to_exec = func_dict[func_key]
|
||||
if isinstance(funcs_to_exec, list):
|
||||
@@ -243,7 +243,7 @@ def call_module(module, args=None, kwargs=None):
|
||||
forward_func = module.forward
|
||||
sig = inspect.signature(forward_func)
|
||||
param_nums = len(sig.parameters)
|
||||
feed_nums = len(args) + len(kwargs)
|
||||
len(args) + len(kwargs)
|
||||
args_needed_nums = param_nums - len(kwargs)
|
||||
args_needed = args[:args_needed_nums]
|
||||
if isinstance(module, CheckpointModule):
|
||||
@@ -256,17 +256,17 @@ def call_module(module, args=None, kwargs=None):
|
||||
|
||||
|
||||
def customized_partition(exec_seq):
|
||||
'''
|
||||
"""
|
||||
This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
|
||||
annotation to note the partition point.
|
||||
'''
|
||||
"""
|
||||
customized_parts = {}
|
||||
start = 0
|
||||
stop = 0
|
||||
rank = 0
|
||||
for element in exec_seq:
|
||||
if isinstance(element, str):
|
||||
if element == 'SPLIT_NODE':
|
||||
if element == "SPLIT_NODE":
|
||||
customized_parts[rank] = [(start, stop)]
|
||||
start = stop
|
||||
rank += 1
|
||||
|
Reference in New Issue
Block a user