mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -12,17 +12,9 @@ from torch import autograd, nn, optim
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
|
||||
from colossalai.legacy.pipeline.middleware import Partition, Topo
|
||||
from colossalai.legacy.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.legacy.pipeline.rpc.utils import (
|
||||
get_batch_lengths,
|
||||
pyobj_map,
|
||||
pytree_filter,
|
||||
pytree_map,
|
||||
split_batch,
|
||||
tensor_shape_list,
|
||||
type_detail,
|
||||
)
|
||||
from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
@@ -33,7 +25,7 @@ class Phase(Enum):
|
||||
|
||||
|
||||
class UniqueKey:
|
||||
__slots__ = ('microbatch_id', 'phase')
|
||||
__slots__ = ("microbatch_id", "phase")
|
||||
microbatch_id: int
|
||||
phase: Phase
|
||||
|
||||
@@ -48,12 +40,22 @@ class UniqueKey:
|
||||
return tuple.__hash__((self.microbatch_id, self.phase))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})'
|
||||
return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})"
|
||||
|
||||
|
||||
class WorkItem:
|
||||
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
|
||||
'num_microbatches', 'forward_only')
|
||||
__slots__ = (
|
||||
"stage_id",
|
||||
"phase",
|
||||
"args",
|
||||
"kwargs",
|
||||
"output",
|
||||
"refcount",
|
||||
"microbatch_id",
|
||||
"batch_id",
|
||||
"num_microbatches",
|
||||
"forward_only",
|
||||
)
|
||||
|
||||
stage_id: int
|
||||
phase: Phase
|
||||
@@ -66,50 +68,45 @@ class WorkItem:
|
||||
num_microbatches: int
|
||||
forward_only: bool
|
||||
|
||||
def __init__(self,
|
||||
stage_id,
|
||||
phase,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
microbatch_id,
|
||||
batch_id,
|
||||
num_microbatches,
|
||||
forward_only,
|
||||
refcount=0) -> None:
|
||||
def __init__(
|
||||
self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0
|
||||
) -> None:
|
||||
for attr_name in self.__slots__:
|
||||
setattr(self, attr_name, locals()[attr_name])
|
||||
|
||||
|
||||
class BackwardCache:
|
||||
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
|
||||
__slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs")
|
||||
checkpoint: bool
|
||||
stage_input_args: Tuple[Any]
|
||||
stage_input_kwargs: Dict[Any, Any]
|
||||
stage_outputs: Tuple[Any]
|
||||
|
||||
def __init__(self,
|
||||
stage_input_args: Tuple[Any],
|
||||
stage_input_kwargs: Dict[Any, Any] = None,
|
||||
stage_outputs: Tuple[Any] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
stage_input_args: Tuple[Any],
|
||||
stage_input_kwargs: Dict[Any, Any] = None,
|
||||
stage_outputs: Tuple[Any] = None,
|
||||
checkpoint: bool = False,
|
||||
) -> None:
|
||||
for arg_name in self.__slots__:
|
||||
setattr(self, arg_name, locals()[arg_name])
|
||||
|
||||
|
||||
class WorkerBase(ABC):
|
||||
|
||||
def __init__(self,
|
||||
partition_fn: Callable,
|
||||
partition_args: tuple,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
partition_fn: Callable,
|
||||
partition_args: tuple,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pp_rank = pp_rank
|
||||
@@ -150,11 +147,11 @@ class WorkerBase(ABC):
|
||||
self._initialize_context_container()
|
||||
|
||||
# main loop
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True)
|
||||
self.main_loop_thread.start()
|
||||
|
||||
def _get_future_by_device(self):
|
||||
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
|
||||
return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device])
|
||||
|
||||
def _initialize_outstanding_range(self):
|
||||
outstanding_range = None
|
||||
@@ -199,12 +196,13 @@ class WorkerBase(ABC):
|
||||
# lifecycle management for DAG scheduler
|
||||
if output_work_item.phase == Phase.FORWARD:
|
||||
lifecycle = len(self.get_consumer_stage_ids())
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
lifecycle += 1
|
||||
elif output_work_item.phase == Phase.BACKWARD:
|
||||
lifecycle = len(self.get_producer_stage_ids())
|
||||
if self.is_model_input() and self._is_last_step(
|
||||
output_work_item): # an extra reference for ensure_backward
|
||||
output_work_item
|
||||
): # an extra reference for ensure_backward
|
||||
lifecycle += 1
|
||||
else:
|
||||
lifecycle = 0
|
||||
@@ -234,9 +232,9 @@ class WorkerBase(ABC):
|
||||
# offset supports get partial output to reduce comm costs.
|
||||
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
|
||||
output = self._get_output_all(key, ref_use, rank)
|
||||
if offsets is None: # get all for non iterable output
|
||||
if offsets is None: # get all for non iterable output
|
||||
return output
|
||||
else: # get part for iterable output
|
||||
else: # get part for iterable output
|
||||
output = [output[i] for i in offsets]
|
||||
return output
|
||||
|
||||
@@ -252,12 +250,12 @@ class WorkerBase(ABC):
|
||||
|
||||
def get_partition(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
return self.module_partition
|
||||
|
||||
def get_partition_state_dict(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
return self.module_partition.state_dict()
|
||||
|
||||
def _make_args_kwargs(self, microbatch, merge=False):
|
||||
@@ -293,8 +291,17 @@ class WorkerBase(ABC):
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.FORWARD,
|
||||
args,
|
||||
kwargs,
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
@@ -314,15 +321,25 @@ class WorkerBase(ABC):
|
||||
for off in self_input_offsets:
|
||||
self_arg_lst.append(arg_lst[off])
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.FORWARD,
|
||||
self_arg_lst,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# put input tensor which other nodes need into output_list as Phase.INPUT
|
||||
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
work_item_remote = WorkItem(
|
||||
self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only
|
||||
)
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list[recv_input_key] = work_item_remote
|
||||
@@ -343,8 +360,17 @@ class WorkerBase(ABC):
|
||||
output = self._get_future_by_device()
|
||||
grad_wrt_loss = None
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, False)
|
||||
work_item = WorkItem(
|
||||
self.pp_rank,
|
||||
Phase.BACKWARD,
|
||||
grad_wrt_loss,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
False,
|
||||
)
|
||||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
@@ -367,7 +393,7 @@ class WorkerBase(ABC):
|
||||
producer_stage_ids = self.get_producer_stage_ids()
|
||||
producer_num = len(producer_stage_ids)
|
||||
if self.need_model_input():
|
||||
producer_num += 1 # for input partition
|
||||
producer_num += 1 # for input partition
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
@@ -376,9 +402,9 @@ class WorkerBase(ABC):
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
offsets = self._get_input_offsets_by_index(target_index=0)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
|
||||
rank=self.pp_rank,
|
||||
offsets=offsets)
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
for i in range(0, producer_num - 1):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
@@ -386,11 +412,12 @@ class WorkerBase(ABC):
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
target_index = i + 1
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
@@ -399,14 +426,24 @@ class WorkerBase(ABC):
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
target_index = i
|
||||
offsets = self._get_input_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_forward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
producer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
work_item_from_producer = WorkItem(
|
||||
stage_id,
|
||||
Phase.FORWARD,
|
||||
subscribe_forward_futures,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
forward_only,
|
||||
)
|
||||
|
||||
return work_item_from_producer
|
||||
|
||||
@@ -441,15 +478,25 @@ class WorkerBase(ABC):
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||
target_index = i
|
||||
offsets = self._get_output_offsets_by_index(target_index=target_index)
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
if offsets is not None and len(offsets) == 0: # no need to do rpc
|
||||
subscribe_backward_futures[target_index] = []
|
||||
else:
|
||||
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
|
||||
consumer_output_key, rank=self.pp_rank, offsets=offsets)
|
||||
consumer_output_key, rank=self.pp_rank, offsets=offsets
|
||||
)
|
||||
|
||||
# flatten args
|
||||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, False)
|
||||
work_item_from_consumer = WorkItem(
|
||||
stage_id,
|
||||
Phase.BACKWARD,
|
||||
subscribe_backward_futures,
|
||||
{},
|
||||
output,
|
||||
microbatch_id,
|
||||
None,
|
||||
self.num_microbatches,
|
||||
False,
|
||||
)
|
||||
|
||||
return work_item_from_consumer
|
||||
|
||||
@@ -524,8 +571,8 @@ class WorkerBase(ABC):
|
||||
|
||||
def get_topo(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
if hasattr(self.module_partition, '_topo'):
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
|
||||
if hasattr(self.module_partition, "_topo"):
|
||||
return self.module_partition._topo
|
||||
else:
|
||||
return None
|
||||
@@ -564,12 +611,12 @@ class WorkerBase(ABC):
|
||||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if target_index == src_index:
|
||||
if output_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(src_offset)
|
||||
@@ -584,7 +631,6 @@ class WorkerBase(ABC):
|
||||
consumer_stage_ids = self.get_consumer_stage_ids()
|
||||
for val_list in output_vals:
|
||||
# An output may be passed to many down stages.
|
||||
target = None
|
||||
for val_pos in val_list.get():
|
||||
dst_partition_id = val_pos.partition_id
|
||||
dst_offset = val_pos.offset
|
||||
@@ -597,7 +643,7 @@ class WorkerBase(ABC):
|
||||
break
|
||||
if target_index == dst_index:
|
||||
if input_len == 1:
|
||||
res = None # offset = None to get all outputs
|
||||
res = None # offset = None to get all outputs
|
||||
return res
|
||||
else:
|
||||
res.append(dst_offset)
|
||||
@@ -623,7 +669,7 @@ class WorkerBase(ABC):
|
||||
flatten_args = []
|
||||
if self.is_first_stage():
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
else: # get by offset
|
||||
else: # get by offset
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
|
||||
@@ -652,7 +698,7 @@ class WorkerBase(ABC):
|
||||
if stage_id == src_stage_id:
|
||||
src_index += i
|
||||
break
|
||||
else: # data from input partition
|
||||
else: # data from input partition
|
||||
src_index = 0
|
||||
# when output_len = 1, not iterable
|
||||
if output_len == 1:
|
||||
@@ -679,7 +725,7 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
for i, arg in enumerate(args_or_kwargs):
|
||||
args_or_kwargs[i] = arg.wait()
|
||||
if args_or_kwargs is not None: # get by offset
|
||||
if args_or_kwargs is not None: # get by offset
|
||||
flatten_args = []
|
||||
topo: Topo = self.get_topo()
|
||||
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
|
||||
@@ -719,7 +765,7 @@ class WorkerBase(ABC):
|
||||
@abstractmethod
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
"""
|
||||
this method control the order of the microbatch to consume
|
||||
this method control the order of the microbatch to consume
|
||||
"""
|
||||
|
||||
def is_first_stage(self):
|
||||
@@ -761,7 +807,7 @@ class WorkerBase(ABC):
|
||||
kwargs = work_item.kwargs
|
||||
microbatch_id = work_item.microbatch_id
|
||||
forward_only = work_item.forward_only
|
||||
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
|
||||
data_process_func = getattr(self, "data_process_func", self._default_data_process_func)
|
||||
consume_result = None
|
||||
|
||||
is_first_stage = self.is_first_stage()
|
||||
@@ -787,10 +833,12 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
|
||||
process_types=torch.device) # change devices from last stage to current device
|
||||
args_kwargs = pyobj_map(
|
||||
args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
args_kwargs = pyobj_map(
|
||||
args_kwargs, fn=lambda x: self.device, process_types=torch.device
|
||||
) # change devices from last stage to current device
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
@@ -851,16 +899,16 @@ class WorkerBase(ABC):
|
||||
use_checkpoint = False
|
||||
|
||||
if not forward_only:
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
|
||||
stage_input_kwargs,
|
||||
stage_outputs,
|
||||
checkpoint=use_checkpoint)
|
||||
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(
|
||||
stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint
|
||||
)
|
||||
consume_result = pyobj_map(
|
||||
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in
|
||||
|
||||
# if not forward_only, do the backward
|
||||
if not forward_only:
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
self._begin_backward(microbatch_id)
|
||||
|
||||
elif phase == Phase.BACKWARD:
|
||||
@@ -872,7 +920,9 @@ class WorkerBase(ABC):
|
||||
self.backward_times += 1
|
||||
self.outstanding -= 1
|
||||
|
||||
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
|
||||
assert (
|
||||
microbatch_id in self.microbatch_id_to_backward_cache
|
||||
), f"microbatch_id {microbatch_id} not in backward cache"
|
||||
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
|
||||
|
||||
stage_outputs = backward_cache.stage_outputs
|
||||
@@ -906,8 +956,9 @@ class WorkerBase(ABC):
|
||||
filtered_grads.append(grad)
|
||||
|
||||
stage_outputs = filtered_outputs
|
||||
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
grad_tensors = pyobj_map(
|
||||
filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
@@ -920,8 +971,8 @@ class WorkerBase(ABC):
|
||||
else:
|
||||
consume_result.append(None)
|
||||
consume_result = pyobj_map(
|
||||
consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
|
||||
) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
@@ -929,7 +980,7 @@ class WorkerBase(ABC):
|
||||
return consume_result
|
||||
|
||||
def _get_store_len(self):
|
||||
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
|
||||
return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}"
|
||||
|
||||
def _get_parameter_grad_sum(self):
|
||||
grad_sum = 0
|
||||
@@ -1014,19 +1065,20 @@ class WorkerBase(ABC):
|
||||
|
||||
|
||||
class PipelineEngineBase(ABC, nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
worker_type,
|
||||
partition_fn: Callable,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
use_1F1B=False,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
worker_type,
|
||||
partition_fn: Callable,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
use_1F1B=False,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.worker_type = worker_type
|
||||
self.partition_fn: Callable = partition_fn
|
||||
@@ -1056,12 +1108,12 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
data_process_func = self.data_process_func
|
||||
if data_process_func is not None:
|
||||
assert callable(data_process_func), "data_process_func must be a function"
|
||||
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
|
||||
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
|
||||
assert "<locals>" not in data_process_func.__repr__(), "data_process_func must be a global function"
|
||||
assert "<lambda>" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(
|
||||
sig.parameters
|
||||
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
|
||||
assert (
|
||||
len(sig.parameters) == 2
|
||||
), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
|
||||
|
||||
def _get_actual_stage_num(self) -> int:
|
||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||
@@ -1104,19 +1156,33 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
||||
partition_args = (partition_id, chunk, actual_stage_num)
|
||||
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
||||
if device[:4] == 'cuda':
|
||||
device = f'cuda:{rpc_worker_id}'
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||
worker_type,
|
||||
args=(partition_fn, partition_args, pp_rank,
|
||||
actual_stage_num, num_microbatches, device,
|
||||
criterion, metric, checkpoint, data_process_func))
|
||||
if device[:4] == "cuda":
|
||||
device = f"cuda:{rpc_worker_id}"
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(
|
||||
rpc_worker_id,
|
||||
worker_type,
|
||||
args=(
|
||||
partition_fn,
|
||||
partition_args,
|
||||
pp_rank,
|
||||
actual_stage_num,
|
||||
num_microbatches,
|
||||
device,
|
||||
criterion,
|
||||
metric,
|
||||
checkpoint,
|
||||
data_process_func,
|
||||
),
|
||||
)
|
||||
|
||||
# let each worker know global worker rref (include itself)
|
||||
sync_futs = []
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
|
||||
self.pp_rank_to_worker_rref)
|
||||
fut = (
|
||||
self.pp_rank_to_worker_rref[pp_rank]
|
||||
.rpc_async(timeout=0)
|
||||
.sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||
)
|
||||
sync_futs.append(fut)
|
||||
|
||||
for fut in sync_futs:
|
||||
@@ -1157,8 +1223,9 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
def get_output_pp_ranks(self) -> List[int]:
|
||||
return [self._get_actual_stage_num() - 1]
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
|
||||
output_pp_ranks: List[int], ret_future):
|
||||
def _consume_constraint(
|
||||
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
|
||||
):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
use_1F1B = self.use_1F1B
|
||||
if microbatch_id >= actual_stage_num:
|
||||
@@ -1206,7 +1273,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
fut = worker_rref.rpc_async().get_output_by_key(
|
||||
key, offsets=[]) # only ensure the res exists, no need for real data.
|
||||
key, offsets=[]
|
||||
) # only ensure the res exists, no need for real data.
|
||||
backward_result.append(fut)
|
||||
|
||||
for fut in backward_result:
|
||||
@@ -1244,11 +1312,14 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
|
||||
if labels is not None and not forward_only:
|
||||
assert hasattr(
|
||||
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||
self, "optimizer_class"
|
||||
), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||
|
||||
num_microbatches = self.num_microbatches
|
||||
|
||||
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
|
||||
assert (
|
||||
batch_length >= num_microbatches
|
||||
), "num_microbatches is greater than the size of a batch, which is illegal"
|
||||
microbatch_size = math.ceil(batch_length / num_microbatches)
|
||||
device = self.device
|
||||
|
||||
@@ -1285,10 +1356,10 @@ class PipelineEngineBase(ABC, nn.Module):
|
||||
# collect forward result
|
||||
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
||||
|
||||
if not forward_only and hasattr(self, 'optimizer_class'):
|
||||
if not forward_only and hasattr(self, "optimizer_class"):
|
||||
self.step()
|
||||
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
return forward_result
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
|
Reference in New Issue
Block a user