Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
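As context for the diff below: the updated pre-commit run applies a Black-style layout across the repository (double quotes, exploded call/def signatures with trailing commas). A minimal sketch of the effect follows; it is illustrative only, the function name is made up, and the actual hook configuration is not part of this page.

# Before (roughly): single quotes, hanging-indent signature
# def build(model, optimizer,
#           verbose=True):
#     return {'model': model}
#
# After: double quotes, one argument per line, trailing comma
def build(
    model,
    optimizer,
    verbose=True,
):
    return {"model": model, "optimizer": optimizer, "verbose": verbose}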
@@ -1,4 +1,4 @@
from ._base_engine import Engine
from .gradient_handler import *

__all__ = ['Engine']
__all__ = ["Engine"]
@@ -59,15 +59,17 @@ class Engine:
`Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
"""

def __init__(self,
model: Module,
optimizer: "OptimizerWrapper",
criterion: Optional[_Loss] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
ophook_list: Optional[List[BaseOpHook]] = None,
verbose: bool = True,
schedule: Optional[BaseSchedule] = None):
def __init__(
self,
model: Module,
optimizer: "OptimizerWrapper",
criterion: Optional[_Loss] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
ophook_list: Optional[List[BaseOpHook]] = None,
verbose: bool = True,
schedule: Optional[BaseSchedule] = None,
):
self._model = model
self._optimizer = optimizer
self._criterion = criterion
@@ -76,7 +78,7 @@ class Engine:
self._logger = get_dist_logger()

# state
self.training = True  # default
self.training = True  # default

# build gradient handler
if gradient_handlers:
@@ -91,8 +93,9 @@ class Engine:

# build schedule
if schedule:
assert isinstance(schedule, BaseSchedule), \
f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
assert isinstance(
schedule, BaseSchedule
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
self._schedule = schedule
else:
self._schedule = NonPipelineSchedule()
@@ -149,13 +152,11 @@ class Engine:
logger.warning(f"removing hooks is currently not supported")

def zero_grad(self):
"""Set the gradient of parameters to zero
"""
"""Set the gradient of parameters to zero"""
self.optimizer.zero_grad()

def step(self):
"""Execute parameter update
"""
"""Execute parameter update"""
self._all_reduce_gradients()
self.optimizer.clip_grad_by_norm(self._clip_grad_norm)
return self.optimizer.step()
@@ -192,8 +193,7 @@ class Engine:
return self.model(*args, **kwargs)

def _all_reduce_gradients(self):
"""Handles all-reduce operations of gradients across different parallel groups.
"""
"""Handles all-reduce operations of gradients across different parallel groups."""
for handler in self._gradient_handlers:
handler.handle_gradient()

@@ -208,13 +208,11 @@ class Engine:
return output, label, loss

def train(self):
"""Sets the model to training mode.
"""
"""Sets the model to training mode."""
self.training = True
self._model.train()

def eval(self):
"""Sets the model to evaluation mode.
"""
"""Sets the model to evaluation mode."""
self.training = False
self._model.eval()
@@ -14,17 +14,22 @@ from ._gradient_accumulation import (
)

__all__ = [
'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
'GradAccumGradientHandler'
"accumulate_gradient",
"GradAccumDataloader",
"GradAccumOptimizer",
"GradAccumLrSchedulerByStep",
"GradAccumGradientHandler",
]


def accumulate_gradient(model: nn.Module,
optimizer: Optimizer,
dataloader: Iterable,
accumulate_size: int,
gradient_handlers: List[BaseGradientHandler] = None,
lr_scheduler: _LRScheduler = None):
def accumulate_gradient(
model: nn.Module,
optimizer: Optimizer,
dataloader: Iterable,
accumulate_size: int,
gradient_handlers: List[BaseGradientHandler] = None,
lr_scheduler: _LRScheduler = None,
):
r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.

Args:
@@ -272,8 +272,9 @@ class GradAccumGradientHandler:
"""

def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None:
assert isinstance(grad_handler, BaseGradientHandler), \
f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}'
assert isinstance(
grad_handler, BaseGradientHandler
), f"expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}"
self.grad_handler = grad_handler
self.accumulate_size = accumulate_size
self.accumulate_step = 0
@@ -6,6 +6,10 @@ from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler

__all__ = [
'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
'MoeGradientHandler', 'SequenceParallelGradientHandler'
"BaseGradientHandler",
"DataParallelGradientHandler",
"ZeROGradientHandler",
"PipelineSharedModuleGradientHandler",
"MoeGradientHandler",
"SequenceParallelGradientHandler",
]
@@ -22,4 +22,3 @@ class BaseGradientHandler(ABC):
"""A method to accumulate gradients across different parallel groups. Users should
write their own functions or just use the functions in pre-defined subclasses.
"""
pass
@@ -20,8 +20,7 @@ class DataParallelGradientHandler(BaseGradientHandler):
"""

def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
"""A method running a all-reduce operation in a data parallel group."""
# TODO: add memory buffer
if gpc.data_parallel_size > 1:
bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
@@ -42,5 +42,6 @@ class MoeGradientHandler(BaseGradientHandler):

for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
bucket_allreduce(param_list=epsize_param_dict[ep_size],
group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
bucket_allreduce(
param_list=epsize_param_dict[ep_size], group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group
)
@@ -26,17 +26,21 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
"""

def handle_gradient(self):
"""A method running a all-reduce operation in sub pipeline parallel groups.
"""
"""A method running a all-reduce operation in sub pipeline parallel groups."""
if gpc.pipeline_parallel_size > 1:
# bucketize and all-reduce
buckets = defaultdict(lambda: defaultdict(list))
# Pack the buckets.
for param in self._model.parameters():
group = getattr(param, 'pipeline_shared_module_pg', None)
if param.requires_grad and group is not None and (
(hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null())
or param.grad is not None):
group = getattr(param, "pipeline_shared_module_pg", None)
if (
param.requires_grad
and group is not None
and (
(hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
or param.grad is not None
)
):
tp = param.data.type()
buckets[group][tp].append(param)

@@ -44,7 +48,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
for group, group_buckets in buckets.items():
for tp, bucket in group_buckets.items():
grads = [
param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data
param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
for param in bucket
]
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
@@ -20,7 +20,6 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
"""

def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
"""A method running a all-reduce operation in a data parallel group."""
if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1:
bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP))
@@ -16,6 +16,5 @@ class ZeROGradientHandler(BaseGradientHandler):
"""

def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
"""A method running a all-reduce operation in a data parallel group."""
self._optimizer.sync_grad()
@@ -2,4 +2,4 @@ from ._base_schedule import BaseSchedule
from ._non_pipeline_schedule import NonPipelineSchedule
from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape

__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
__all__ = ["BaseSchedule", "NonPipelineSchedule", "PipelineSchedule", "InterleavedPipelineSchedule", "get_tensor_shape"]
@@ -47,7 +47,8 @@ class BaseSchedule(ABC):
|
||||
data = {k: self._move_tensor(v) for k, v in data.items()}
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}"
|
||||
)
|
||||
return data
|
||||
|
||||
def _get_batch_size(self, data):
|
||||
@@ -72,7 +73,7 @@ class BaseSchedule(ABC):
|
||||
Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label).
|
||||
"""
|
||||
if data_iter is None:
|
||||
raise RuntimeError('Dataloader is not defined.')
|
||||
raise RuntimeError("Dataloader is not defined.")
|
||||
batch_data = next(data_iter)
|
||||
|
||||
if to_gpu:
|
||||
@@ -81,17 +82,17 @@ class BaseSchedule(ABC):
|
||||
return batch_data
|
||||
|
||||
def pre_processing(self, engine):
|
||||
"""To perform actions before running the schedule.
|
||||
"""
|
||||
pass
|
||||
"""To perform actions before running the schedule."""
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward_step(self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True):
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function over a batch of dataset for training or evaluation.
|
||||
|
||||
Args:
|
||||
@@ -101,7 +102,6 @@ class BaseSchedule(ABC):
|
||||
return_loss (bool, optional): If False, the loss won't be returned.
|
||||
return_output_label (bool, optional): If False, the output and label won't be returned.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _call_engine(engine, inputs):
|
||||
@@ -113,13 +113,14 @@ class BaseSchedule(ABC):
|
||||
return engine(**inputs)
|
||||
else:
|
||||
TypeError(
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")
|
||||
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _call_engine_criterion(engine, outputs, labels):
|
||||
assert isinstance(outputs,
|
||||
(torch.Tensor, list, tuple,
|
||||
dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
||||
assert isinstance(
|
||||
outputs, (torch.Tensor, list, tuple, dict)
|
||||
), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
if isinstance(labels, torch.Tensor):
|
||||
@@ -134,6 +135,8 @@ class BaseSchedule(ABC):
|
||||
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
|
||||
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
|
||||
else:
|
||||
raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
raise TypeError(
|
||||
f"Expected model outputs and labels to be of type torch.Tensor ' \
|
||||
'(which is auto-converted to tuple), list, tuple, or dict, ' \
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
|
||||
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
|
||||
)
|
||||
|
@@ -37,19 +37,22 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 1, \
|
||||
'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
|
||||
'which is a tuple of tensors for the current batch, ' \
|
||||
'i.e. data_process_func(dataloader_output).'
|
||||
assert len(sig.parameters) == 1, (
|
||||
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
|
||||
"which is a tuple of tensors for the current batch, "
|
||||
"i.e. data_process_func(dataloader_output)."
|
||||
)
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
||||
def forward_backward_step(self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True):
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine,
|
||||
data_iter: Iterable,
|
||||
forward_only: bool = False,
|
||||
return_loss: bool = True,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
@@ -64,8 +67,9 @@ class NonPipelineSchedule(BaseSchedule):
|
||||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
assert forward_only or return_loss, \
|
||||
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
batch_data = self.load_batch(data_iter)
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
|
@@ -18,14 +18,18 @@ from ._base_schedule import BaseSchedule
|
||||
|
||||
|
||||
def get_tensor_shape():
|
||||
if hasattr(gpc.config, 'TENSOR_SHAPE'):
|
||||
if hasattr(gpc.config, "TENSOR_SHAPE"):
|
||||
return gpc.config.TENSOR_SHAPE
|
||||
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
return None
|
||||
|
||||
if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
|
||||
gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
|
||||
if (
|
||||
hasattr(gpc.config, "SEQ_LENGTH")
|
||||
and hasattr(gpc.config, "GLOBAL_BATCH_SIZE")
|
||||
and hasattr(gpc.config, "GLOBAL_BATCH_SIZE")
|
||||
and hasattr(gpc.config, "HIDDEN_SIZE")
|
||||
):
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||
else:
|
||||
@@ -35,8 +39,11 @@ def get_tensor_shape():
|
||||
else:
|
||||
seq_size = 1
|
||||
|
||||
tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
|
||||
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LENGTH // seq_size,
|
||||
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
return tensor_shape
|
||||
else:
|
||||
return None
|
||||
@@ -49,7 +56,7 @@ def pack_return_tensors(return_tensors):
|
||||
elif isinstance(output[0], (list, tuple)):
|
||||
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
|
||||
else:
|
||||
raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
|
||||
raise TypeError(f"Output of model must be tensor or list/tuple of tensors")
|
||||
if isinstance(label[0], torch.Tensor):
|
||||
label = torch.cat(label, dim=0)
|
||||
else:
|
||||
@@ -88,28 +95,31 @@ class PipelineSchedule(BaseSchedule):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_microbatches,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_microbatches,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
):
|
||||
# we need to make sure that the signature of the data_process_func is valid
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 2, \
|
||||
'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \
|
||||
'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
|
||||
'i.e. data_process_func(stage_output, dataloader_output).'
|
||||
assert len(sig.parameters) == 2, (
|
||||
"The data_process_func only takes in two parameters for NonPipelineSchedule, "
|
||||
"which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, "
|
||||
"i.e. data_process_func(stage_output, dataloader_output)."
|
||||
)
|
||||
|
||||
super().__init__(data_process_func=data_process_func)
|
||||
|
||||
assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
|
||||
assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}"
|
||||
|
||||
self.num_microbatches = num_microbatches
|
||||
self.dtype = torch.float
|
||||
assert not isinstance(tensor_shape,
|
||||
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
||||
assert not isinstance(
|
||||
tensor_shape, int
|
||||
), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
||||
if tensor_shape is None:
|
||||
self.tensor_shape = tensor_shape
|
||||
elif isinstance(tensor_shape, torch.Size):
|
||||
@@ -128,26 +138,25 @@ class PipelineSchedule(BaseSchedule):
|
||||
# Pipeline schedule just puts data in memory
|
||||
batch_data = super().load_batch(data_iter, to_gpu=False)
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.num_microbatches == 0, \
|
||||
"Batch size should divided by the number of microbatches"
|
||||
assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
self.batch_data = batch_data
|
||||
|
||||
def _get_data_slice(self, data, offset):
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data[offset:offset + self.microbatch_size]
|
||||
return data[offset : offset + self.microbatch_size]
|
||||
elif isinstance(data, (list, tuple)):
|
||||
data_dict = {}
|
||||
for element in data:
|
||||
if isinstance(element, dict):
|
||||
data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
|
||||
data_dict.update({k: v[offset : offset + self.microbatch_size] for k, v in element.items()})
|
||||
elif data_dict:
|
||||
data_dict['label'] = element[offset:offset + self.microbatch_size]
|
||||
data_dict["label"] = element[offset : offset + self.microbatch_size]
|
||||
if data_dict:
|
||||
return data_dict
|
||||
return [val[offset:offset + self.microbatch_size] for val in data]
|
||||
return [val[offset : offset + self.microbatch_size] for val in data]
|
||||
elif isinstance(data, dict):
|
||||
return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
|
||||
return {k: v[offset : offset + self.microbatch_size] for k, v in data.items()}
|
||||
else:
|
||||
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
||||
|
||||
@@ -180,8 +189,8 @@ class PipelineSchedule(BaseSchedule):
|
||||
return model(*data)
|
||||
elif isinstance(data, dict):
|
||||
stage_output = None
|
||||
if 'stage_output' in data:
|
||||
stage_output = data.pop('stage_output')
|
||||
if "stage_output" in data:
|
||||
stage_output = data.pop("stage_output")
|
||||
if stage_output is None:
|
||||
return model(**data)
|
||||
elif isinstance(stage_output, torch.Tensor):
|
||||
@@ -198,7 +207,7 @@ class PipelineSchedule(BaseSchedule):
|
||||
def _get_actual_forward_func(self, module):
|
||||
if isinstance(module, NaiveAMPModel):
|
||||
sig = inspect.signature(module.model.forward)
|
||||
elif hasattr(module, 'colo_attr'):
|
||||
elif hasattr(module, "colo_attr"):
|
||||
sig = inspect.signature(module.module.forward)
|
||||
else:
|
||||
sig = inspect.signature(module.forward)
|
||||
@@ -221,9 +230,9 @@ class PipelineSchedule(BaseSchedule):
|
||||
_, label = micro_batch_data
|
||||
elif isinstance(micro_batch_data, dict):
|
||||
data = {}
|
||||
data['stage_output'] = stage_output
|
||||
if 'label' in micro_batch_data:
|
||||
label = micro_batch_data.pop('label')
|
||||
data["stage_output"] = stage_output
|
||||
if "label" in micro_batch_data:
|
||||
label = micro_batch_data.pop("label")
|
||||
else:
|
||||
label = None
|
||||
load_data = micro_batch_data
|
||||
@@ -263,7 +272,7 @@ class PipelineSchedule(BaseSchedule):
|
||||
else:
|
||||
if isinstance(output_obj, torch.Tensor):
|
||||
self._logger.debug(
|
||||
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
|
||||
f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}"
|
||||
)
|
||||
return output_obj
|
||||
|
||||
@@ -325,12 +334,13 @@ class PipelineSchedule(BaseSchedule):
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
|
||||
assert forward_only or return_loss, \
|
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
self.load_batch(data_iter)
|
||||
num_warmup_microbatches = \
|
||||
(gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
- gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
|
||||
num_warmup_microbatches = (
|
||||
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
||||
)
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
|
||||
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
|
||||
|
||||
@@ -354,14 +364,12 @@ class PipelineSchedule(BaseSchedule):
|
||||
for i in range(num_warmup_microbatches):
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shapes = comm.recv_obj_meta(ft_shapes)
|
||||
input_obj = comm.recv_forward(ft_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
output_obj = self._forward_step(engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss)
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
output_obj = self._forward_step(
|
||||
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
|
||||
)
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
if isinstance(output_obj, torch.Tensor):
|
||||
bt_shapes = output_obj.shape
|
||||
@@ -382,32 +390,29 @@ class PipelineSchedule(BaseSchedule):
|
||||
if num_microbatches_remaining > 0:
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
ft_shapes = comm.recv_obj_meta(ft_shapes)
|
||||
input_obj = comm.recv_forward(ft_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self._forward_step(engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss)
|
||||
output_obj = self._forward_step(
|
||||
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
|
||||
)
|
||||
if forward_only:
|
||||
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
if not last_iteration:
|
||||
input_obj = comm.recv_forward(ft_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj = comm.recv_forward(
|
||||
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
else:
|
||||
output_obj_grad = comm.send_forward_recv_backward(output_obj,
|
||||
bt_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
output_obj_grad = comm.send_forward_recv_backward(
|
||||
output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
@@ -424,10 +429,9 @@ class PipelineSchedule(BaseSchedule):
|
||||
input_obj = None
|
||||
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
else:
|
||||
input_obj = comm.send_backward_recv_forward(input_obj_grad,
|
||||
ft_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj = comm.send_backward_recv_forward(
|
||||
input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
# Run cooldown backward passes.
|
||||
if not forward_only:
|
||||
@@ -435,9 +439,9 @@ class PipelineSchedule(BaseSchedule):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
|
||||
output_obj_grad = comm.recv_backward(bt_shapes,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
output_obj_grad = comm.recv_backward(
|
||||
bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
|
||||
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
|
||||
|
||||
@@ -451,13 +455,14 @@ class PipelineSchedule(BaseSchedule):
|
||||
|
||||
|
||||
class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
|
||||
def __init__(self,
|
||||
num_microbatches: int,
|
||||
num_model_chunks: int,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
num_microbatches: int,
|
||||
num_model_chunks: int,
|
||||
data_process_func: Callable = None,
|
||||
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
||||
scatter_gather_tensors: bool = False,
|
||||
):
|
||||
"""A helper schedule class for pipeline parallelism running environment.
|
||||
It uses interleaved 1F1B strategy. Other properties are similar as
|
||||
:class:`NonPipelineSchedule`.
|
||||
@@ -471,20 +476,25 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
scatter_gather_tensors (bool, optional):
|
||||
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
|
||||
"""
|
||||
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
|
||||
'num_microbatches must be an integer multiple of pipeline parallel world size'
|
||||
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
|
||||
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
|
||||
super().__init__(num_microbatches,
|
||||
data_process_func=data_process_func,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather_tensors)
|
||||
assert (
|
||||
num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
|
||||
), "num_microbatches must be an integer multiple of pipeline parallel world size"
|
||||
assert (
|
||||
isinstance(num_model_chunks, int) and num_model_chunks > 0
|
||||
), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}"
|
||||
super().__init__(
|
||||
num_microbatches,
|
||||
data_process_func=data_process_func,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather_tensors,
|
||||
)
|
||||
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
|
||||
gpc.set_virtual_pipeline_parallel_rank(0)
|
||||
self.num_model_chunks = num_model_chunks
|
||||
|
||||
def pre_processing(self, engine):
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
|
||||
if isinstance(engine.model, ShardedModelV2):
|
||||
self.dtype = torch.half
|
||||
elif isinstance(engine.model[0], NaiveAMPModel):
|
||||
@@ -494,7 +504,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
model = model.model
|
||||
sig = inspect.signature(model.forward)
|
||||
for p in sig.parameters.values():
|
||||
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||
assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported"
|
||||
|
||||
def load_batch(self, data_iter):
|
||||
super().load_batch(data_iter)
|
||||
@@ -506,13 +516,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return self._move_to_device(data)
|
||||
|
||||
def _forward_step(self,
|
||||
engine,
|
||||
model_chunk_id,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=True,
|
||||
accum_loss=None):
|
||||
def _forward_step(
|
||||
self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None
|
||||
):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
@@ -528,8 +534,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
|
||||
"""
|
||||
micro_batch_data = self.load_micro_batch(model_chunk_id)
|
||||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
|
||||
engine.model[model_chunk_id])
|
||||
data, label = self._get_data_label_for_current_step(
|
||||
input_obj, micro_batch_data, engine.criterion, engine.model[model_chunk_id]
|
||||
)
|
||||
|
||||
output_obj = self._call_engine(engine.model[model_chunk_id], data)
|
||||
|
||||
@@ -546,7 +553,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
else:
|
||||
if isinstance(output_obj, torch.Tensor):
|
||||
self._logger.debug(
|
||||
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
|
||||
f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}"
|
||||
)
|
||||
return output_obj
|
||||
|
||||
@@ -566,8 +573,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
The loss would be returned only in the last stage.
|
||||
"""
|
||||
assert forward_only or return_loss, \
|
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
self.load_batch(data_iter)
|
||||
model = engine.model
|
||||
input_objs = [[] for _ in range(len(model))]
|
||||
@@ -605,19 +613,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
num_warmup_microbatches = num_microbatches
|
||||
all_warmup_microbatches = True
|
||||
else:
|
||||
num_warmup_microbatches = \
|
||||
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
|
||||
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
|
||||
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
|
||||
num_microbatches_remaining = \
|
||||
num_microbatches - num_warmup_microbatches
|
||||
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
|
||||
|
||||
def get_model_chunk_id(microbatch_id, forward):
|
||||
"""Helper method to get the model chunk ID given the iteration number."""
|
||||
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
|
||||
if not forward:
|
||||
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
|
||||
model_chunk_id = num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def _forward_step_helper(microbatch_id):
|
||||
@@ -629,16 +635,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
|
||||
# forward step
|
||||
if gpc.is_pipeline_first_stage():
|
||||
if len(input_objs[model_chunk_id]) == \
|
||||
len(output_objs[model_chunk_id]):
|
||||
if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]):
|
||||
input_objs[model_chunk_id].append(None)
|
||||
input_obj = input_objs[model_chunk_id][-1]
|
||||
output_obj = self._forward_step(engine,
|
||||
model_chunk_id,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss)
|
||||
output_obj = self._forward_step(
|
||||
engine,
|
||||
model_chunk_id,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
)
|
||||
output_objs[model_chunk_id].append(output_obj)
|
||||
|
||||
# if forward-only, no need to save tensors for a backward pass
|
||||
@@ -670,8 +677,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
|
||||
input_objs[0].append(
|
||||
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors))
|
||||
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
)
|
||||
|
||||
for k in range(num_warmup_microbatches):
|
||||
model_chunk_id = get_model_chunk_id(k, forward=True)
|
||||
@@ -683,8 +690,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
output_obj_shapes[model_chunk_id] = []
|
||||
for out_tensor in output_obj:
|
||||
output_obj_shapes[model_chunk_id].append(out_tensor.shape)
|
||||
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
|
||||
send_tensor_shape_flags[model_chunk_id])
|
||||
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(
|
||||
output_obj, send_tensor_shape_flags[model_chunk_id]
|
||||
)
|
||||
# Determine if tensor should be received from previous stage.
|
||||
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
|
||||
recv_prev = True
|
||||
@@ -701,34 +709,36 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
|
||||
if not gpc.is_pipeline_first_stage():
|
||||
input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
|
||||
input_obj_shapes[next_forward_model_chunk_id])
|
||||
input_obj_shapes[next_forward_model_chunk_id]
|
||||
)
|
||||
# Send and receive tensors as appropriate (send tensors computed
|
||||
# in this iteration; receive tensors for next iteration).
|
||||
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
|
||||
if k == (num_warmup_microbatches - 1) and not forward_only and \
|
||||
not all_warmup_microbatches:
|
||||
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
|
||||
input_obj_grad = None
|
||||
recv_next = True
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
recv_next = False
|
||||
output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
|
||||
input_obj, output_obj_grad = \
|
||||
comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj, input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev, recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj,
|
||||
input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
|
||||
else:
|
||||
input_obj = \
|
||||
comm.send_forward_recv_forward(
|
||||
output_obj,
|
||||
input_shape,
|
||||
recv_prev=recv_prev,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj = comm.send_forward_recv_forward(
|
||||
output_obj,
|
||||
input_shape,
|
||||
recv_prev=recv_prev,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
input_objs[next_forward_model_chunk_id].append(input_obj)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
@@ -771,8 +781,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
recv_next = True
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
|
||||
next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1),
|
||||
forward=False)
|
||||
next_backward_model_chunk_id = get_model_chunk_id(
|
||||
backward_k - (pipeline_parallel_size - 1), forward=False
|
||||
)
|
||||
if next_backward_model_chunk_id == 0:
|
||||
recv_next = False
|
||||
next_backward_model_chunk_id -= 1
|
||||
@@ -787,14 +798,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
|
||||
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
|
||||
# Communicate objs.
|
||||
input_obj, output_obj_grad = \
|
||||
comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj, input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev, recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj,
|
||||
input_obj_grad,
|
||||
input_shape,
|
||||
output_shape,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
|
||||
# Put input_obj and output_obj_grad in data structures in the
|
||||
# right location.
|
||||
@@ -807,8 +820,10 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
if not forward_only:
|
||||
if all_warmup_microbatches:
|
||||
output_obj_grads[num_model_chunks - 1].append(
|
||||
comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
|
||||
scatter_gather_tensors=self.scatter_gather_tensors))
|
||||
comm.recv_backward(
|
||||
output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors
|
||||
)
|
||||
)
|
||||
for k in range(num_microbatches_remaining, num_microbatches):
|
||||
input_obj_grad = _backward_step_helper(k)
|
||||
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
|
||||
@@ -820,11 +835,14 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
recv_next = False
|
||||
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
|
||||
output_obj_grads[next_backward_model_chunk_id].append(
|
||||
comm.send_backward_recv_backward(input_obj_grad,
|
||||
output_shape,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors))
|
||||
comm.send_backward_recv_backward(
|
||||
input_obj_grad,
|
||||
output_shape,
|
||||
recv_next=recv_next,
|
||||
dtype=self.dtype,
|
||||
scatter_gather_tensors=self.scatter_gather_tensors,
|
||||
)
|
||||
)
|
||||
|
||||
if len(return_tensors) > 0:
|
||||
output, label = pack_return_tensors(return_tensors)
|
||||
|
@@ -21,7 +21,7 @@ def pack_return_tensors(return_tensors):
|
||||
elif isinstance(output[0], (list, tuple)):
|
||||
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
|
||||
else:
|
||||
raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
|
||||
raise TypeError(f"Output of model must be tensor or list/tuple of tensors")
|
||||
if isinstance(label[0], torch.Tensor):
|
||||
label = torch.cat(label, dim=0)
|
||||
else:
|
||||
@@ -59,12 +59,9 @@ class PipelineScheduleV2(PipelineSchedule):
|
||||
|
||||
"""
|
||||
|
||||
def forward_backward_step(self,
|
||||
engine: Engine,
|
||||
data_iter: Iterable,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=True) -> Tuple[torch.Tensor]:
|
||||
def forward_backward_step(
|
||||
self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True
|
||||
) -> Tuple[torch.Tensor]:
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
||||
@@ -80,14 +77,15 @@ class PipelineScheduleV2(PipelineSchedule):
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||
"""
|
||||
|
||||
assert forward_only or return_loss, \
|
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
self.load_batch(data_iter)
|
||||
|
||||
# num_warmup_microbatches is the step when not all the processes are working
|
||||
num_warmup_microbatches = \
|
||||
(gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
- gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
|
||||
num_warmup_microbatches = (
|
||||
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
||||
)
|
||||
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
|
||||
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
|
||||
|
||||
@@ -109,11 +107,9 @@ class PipelineScheduleV2(PipelineSchedule):
|
||||
for i in range(num_warmup_microbatches):
|
||||
input_obj = comm.recv_forward()
|
||||
|
||||
output_obj = self._forward_step(engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss)
|
||||
output_obj = self._forward_step(
|
||||
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
|
||||
)
|
||||
|
||||
comm.send_forward(output_obj)
|
||||
|
||||
@@ -129,13 +125,11 @@ class PipelineScheduleV2(PipelineSchedule):
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for i in range(num_microbatches_remaining):
|
||||
last_iteration = (i == (num_microbatches_remaining - 1))
|
||||
last_iteration = i == (num_microbatches_remaining - 1)
|
||||
|
||||
output_obj = self._forward_step(engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss)
|
||||
output_obj = self._forward_step(
|
||||
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
|
||||
)
|
||||
if forward_only:
|
||||
comm.send_forward(output_obj)
|
||||
|