[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -2,4 +2,4 @@ from ._base_schedule import BaseSchedule
from ._non_pipeline_schedule import NonPipelineSchedule
from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
__all__ = ["BaseSchedule", "NonPipelineSchedule", "PipelineSchedule", "InterleavedPipelineSchedule", "get_tensor_shape"]


@@ -47,7 +47,8 @@ class BaseSchedule(ABC):
data = {k: self._move_tensor(v) for k, v in data.items()}
else:
raise TypeError(
f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}"
)
return data
def _get_batch_size(self, data):
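
For context, `_move_tensor` above is the per-value hook inside the dict comprehension; a minimal illustrative sketch of the whole tensor/list/tuple/dict dispatch follows. `move_to` is a hypothetical standalone helper written for this example, not the repo's API:

# Sketch of moving a nested batch to a device, mirroring the dispatch above.
import torch

def move_to(data, device):
    if isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, (list, tuple)):
        return type(data)(move_to(v, device) for v in data)
    elif isinstance(data, dict):
        return {k: move_to(v, device) for k, v in data.items()}
    raise TypeError(
        f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}"
    )
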
@@ -72,7 +73,7 @@ class BaseSchedule(ABC):
Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label).
"""
if data_iter is None:
raise RuntimeError('Dataloader is not defined.')
raise RuntimeError("Dataloader is not defined.")
batch_data = next(data_iter)
if to_gpu:
@@ -81,17 +82,17 @@ class BaseSchedule(ABC):
return batch_data
def pre_processing(self, engine):
"""To perform actions before running the schedule.
"""
pass
"""To perform actions before running the schedule."""
@abstractmethod
def forward_backward_step(self,
engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
return_output_label: bool = True):
def forward_backward_step(
self,
engine,
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function over a batch of dataset for training or evaluation.
Args:
@@ -101,7 +102,6 @@ class BaseSchedule(ABC):
return_loss (bool, optional): If False, the loss won't be returned.
return_output_label (bool, optional): If False, the output and label won't be returned.
"""
pass
@staticmethod
def _call_engine(engine, inputs):
@@ -113,13 +113,14 @@ class BaseSchedule(ABC):
return engine(**inputs)
else:
TypeError(
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
)
@staticmethod
def _call_engine_criterion(engine, outputs, labels):
assert isinstance(outputs,
(torch.Tensor, list, tuple,
dict)), f'Expect output of model to be (torch.Tensor, list, tuple), got {type(outputs)}'
assert isinstance(
outputs, (torch.Tensor, list, tuple, dict)
), f"Expect output of model to be (torch.Tensor, list, tuple), got {type(outputs)}"
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):
@@ -134,6 +135,8 @@ class BaseSchedule(ABC):
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
else:
raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \
raise TypeError(
f"Expected model outputs and labels to be of type torch.Tensor ' \
'(which is auto-converted to tuple), list, tuple, or dict, ' \
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
)
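
One pre-existing issue survives the reformat: in the `_call_engine` hunk above, the fallthrough branch constructs the TypeError but never raises it. A corrected sketch of the dispatch (illustrative, not the repo's code):

import torch

def call_engine(engine, inputs):
    if isinstance(inputs, torch.Tensor):
        return engine(inputs)
    elif isinstance(inputs, (list, tuple)):
        return engine(*inputs)
    elif isinstance(inputs, dict):
        return engine(**inputs)
    else:
        raise TypeError(  # the original code omits this `raise`
            f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
        )
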


@@ -37,19 +37,22 @@ class NonPipelineSchedule(BaseSchedule):
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 1, \
'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
'which is a tuple of tensors for the current batch, ' \
'i.e. data_process_func(dataloader_output).'
assert len(sig.parameters) == 1, (
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
"which is a tuple of tensors for the current batch, "
"i.e. data_process_func(dataloader_output)."
)
super().__init__(data_process_func)
def forward_backward_step(self,
engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True):
def forward_backward_step(
self,
engine,
data_iter: Iterable,
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
@@ -64,8 +67,9 @@ class NonPipelineSchedule(BaseSchedule):
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter)
if self.data_process_func:
data, label = self.data_process_func(batch_data)
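
The `inspect.signature` check above is easy to exercise in isolation; a small sketch, where `validate_data_process_func` is an illustrative name rather than the repo's API:

import inspect

def validate_data_process_func(fn):
    sig = inspect.signature(fn)
    assert len(sig.parameters) == 1, (
        "The data_process_func only takes in one parameter for NonPipelineSchedule, "
        "i.e. data_process_func(dataloader_output)."
    )

validate_data_process_func(lambda batch: batch)   # passes
# validate_data_process_func(lambda a, b: a)      # would raise AssertionError
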


@@ -18,14 +18,18 @@ from ._base_schedule import BaseSchedule
def get_tensor_shape():
if hasattr(gpc.config, 'TENSOR_SHAPE'):
if hasattr(gpc.config, "TENSOR_SHAPE"):
return gpc.config.TENSOR_SHAPE
if not gpc.is_initialized(ParallelMode.PIPELINE):
return None
if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
if (
hasattr(gpc.config, "SEQ_LENGTH")
and hasattr(gpc.config, "GLOBAL_BATCH_SIZE")
and hasattr(gpc.config, "GLOBAL_BATCH_SIZE")
and hasattr(gpc.config, "HIDDEN_SIZE")
):
if gpc.is_initialized(ParallelMode.DATA):
dp_size = gpc.get_world_size(ParallelMode.DATA)
else:
@@ -35,8 +39,11 @@ def get_tensor_shape():
else:
seq_size = 1
tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
tensor_shape = (
gpc.config.SEQ_LENGTH // seq_size,
gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
gpc.config.HIDDEN_SIZE,
)
return tensor_shape
else:
return None
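
Note that the condition tests hasattr(gpc.config, "GLOBAL_BATCH_SIZE") twice, in the original and the reformatted version alike; the reformat is faithful to the existing logic. The shape arithmetic itself is straightforward; a worked example with assumed config values (1024-token sequences, global batch 32, hidden size 768, 4 microbatches, data-parallel size 2, no sequence parallelism):

SEQ_LENGTH, GLOBAL_BATCH_SIZE, HIDDEN_SIZE = 1024, 32, 768
NUM_MICRO_BATCHES, dp_size, seq_size = 4, 2, 1

tensor_shape = (
    SEQ_LENGTH // seq_size,
    GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES,
    HIDDEN_SIZE,
)
print(tensor_shape)  # (1024, 4, 768)
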
@@ -49,7 +56,7 @@ def pack_return_tensors(return_tensors):
elif isinstance(output[0], (list, tuple)):
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
else:
raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
raise TypeError(f"Output of model must be tensor or list/tuple of tensors")
if isinstance(label[0], torch.Tensor):
label = torch.cat(label, dim=0)
else:
@@ -88,28 +95,31 @@ class PipelineSchedule(BaseSchedule):
"""
def __init__(self,
num_microbatches,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
def __init__(
self,
num_microbatches,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False,
):
# we need to make sure that the signature of the data_process_func is valid
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 2, \
'The data_process_func only takes in two parameters for PipelineSchedule, ' \
'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
'i.e. data_process_func(stage_output, dataloader_output).'
assert len(sig.parameters) == 2, (
"The data_process_func only takes in two parameters for NonPipelineSchedule, "
"which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, "
"i.e. data_process_func(stage_output, dataloader_output)."
)
super().__init__(data_process_func=data_process_func)
assert num_microbatches > 0, f'expected num_microbatches to be larger than 0, but got {num_microbatches}'
assert num_microbatches > 0, f"expected num_microbatches to be larger than 0, but got {num_microbatches}"
self.num_microbatches = num_microbatches
self.dtype = torch.float
assert not isinstance(tensor_shape,
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
assert not isinstance(
tensor_shape, int
), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
if tensor_shape is None:
self.tensor_shape = tensor_shape
elif isinstance(tensor_shape, torch.Size):
@@ -128,26 +138,25 @@ class PipelineSchedule(BaseSchedule):
# Pipeline schedule just puts data in memory
batch_data = super().load_batch(data_iter, to_gpu=False)
self.microbatch_offset = 0
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should be divisible by the number of microbatches"
assert self.batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
self.batch_data = batch_data
def _get_data_slice(self, data, offset):
if isinstance(data, torch.Tensor):
return data[offset:offset + self.microbatch_size]
return data[offset : offset + self.microbatch_size]
elif isinstance(data, (list, tuple)):
data_dict = {}
for element in data:
if isinstance(element, dict):
data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
data_dict.update({k: v[offset : offset + self.microbatch_size] for k, v in element.items()})
elif data_dict:
data_dict['label'] = element[offset:offset + self.microbatch_size]
data_dict["label"] = element[offset : offset + self.microbatch_size]
if data_dict:
return data_dict
return [val[offset:offset + self.microbatch_size] for val in data]
return [val[offset : offset + self.microbatch_size] for val in data]
elif isinstance(data, dict):
return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
return {k: v[offset : offset + self.microbatch_size] for k, v in data.items()}
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
@@ -180,8 +189,8 @@ class PipelineSchedule(BaseSchedule):
return model(*data)
elif isinstance(data, dict):
stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
if "stage_output" in data:
stage_output = data.pop("stage_output")
if stage_output is None:
return model(**data)
elif isinstance(stage_output, torch.Tensor):
@@ -198,7 +207,7 @@ class PipelineSchedule(BaseSchedule):
def _get_actual_forward_func(self, module):
if isinstance(module, NaiveAMPModel):
sig = inspect.signature(module.model.forward)
elif hasattr(module, 'colo_attr'):
elif hasattr(module, "colo_attr"):
sig = inspect.signature(module.module.forward)
else:
sig = inspect.signature(module.forward)
@@ -221,9 +230,9 @@ class PipelineSchedule(BaseSchedule):
_, label = micro_batch_data
elif isinstance(micro_batch_data, dict):
data = {}
data['stage_output'] = stage_output
if 'label' in micro_batch_data:
label = micro_batch_data.pop('label')
data["stage_output"] = stage_output
if "label" in micro_batch_data:
label = micro_batch_data.pop("label")
else:
label = None
load_data = micro_batch_data
@@ -263,7 +272,7 @@ class PipelineSchedule(BaseSchedule):
else:
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}"
)
return output_obj
@@ -325,12 +334,13 @@ class PipelineSchedule(BaseSchedule):
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
self.load_batch(data_iter)
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE)
- gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = (
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
)
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
@@ -354,14 +364,12 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
input_obj = comm.recv_forward(
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
output_obj = self._forward_step(
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
if isinstance(output_obj, torch.Tensor):
bt_shapes = output_obj.shape
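
A worked example of the warmup arithmetic above, for an assumed 4-stage pipeline running 8 microbatches; later stages enter steady-state 1F1B sooner:

world_size, num_microbatches = 4, 8
for rank in range(world_size):
    num_warmup = min(world_size - rank - 1, num_microbatches)
    print(f"rank {rank}: warmup={num_warmup}, steady={num_microbatches - num_warmup}")
# rank 0: warmup=3, steady=5
# rank 3: warmup=0, steady=8
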
@@ -382,32 +390,29 @@ class PipelineSchedule(BaseSchedule):
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.recv_forward(
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
)
if forward_only:
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
if not last_iteration:
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.recv_forward(
ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
else:
output_obj_grad = comm.send_forward_recv_backward(output_obj,
bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.send_forward_recv_backward(
output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
@@ -424,10 +429,9 @@ class PipelineSchedule(BaseSchedule):
input_obj = None
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else:
input_obj = comm.send_backward_recv_forward(input_obj_grad,
ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.send_backward_recv_forward(
input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
# Run cooldown backward passes.
if not forward_only:
@@ -435,9 +439,9 @@ class PipelineSchedule(BaseSchedule):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = comm.recv_backward(bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.recv_backward(
bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors
)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
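
Taken together, the three loops (warmup forwards, steady 1F1B pairs, cooldown backwards) give each stage the following per-microbatch order. A schematic, communication-free sketch of that ordering, not the schedule implementation itself:

def one_f_one_b(rank, world_size, num_microbatches):
    warmup = min(world_size - rank - 1, num_microbatches)
    steady = num_microbatches - warmup
    ops = [f"F{i}" for i in range(warmup)]          # warmup forwards
    for i in range(steady):                          # steady 1F1B
        ops += [f"F{warmup + i}", f"B{i}"]
    ops += [f"B{steady + i}" for i in range(warmup)]  # cooldown backwards
    return ops

print(one_f_one_b(rank=1, world_size=4, num_microbatches=6))
# ['F0', 'F1', 'F2', 'B0', 'F3', 'B1', 'F4', 'B2', 'F5', 'B3', 'B4', 'B5']
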
@@ -451,13 +455,14 @@ class PipelineSchedule(BaseSchedule):
class InterleavedPipelineSchedule(PipelineSchedule):
def __init__(self,
num_microbatches: int,
num_model_chunks: int,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
def __init__(
self,
num_microbatches: int,
num_model_chunks: int,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False,
):
"""A helper schedule class for pipeline parallelism running environment.
It uses the interleaved 1F1B strategy. Other properties are similar to
:class:`NonPipelineSchedule`.
@@ -471,20 +476,25 @@ class InterleavedPipelineSchedule(PipelineSchedule):
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
super().__init__(num_microbatches,
data_process_func=data_process_func,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather_tensors)
assert (
num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
), "num_microbatches must be an integer multiple of pipeline parallel world size"
assert (
isinstance(num_model_chunks, int) and num_model_chunks > 0
), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}"
super().__init__(
num_microbatches,
data_process_func=data_process_func,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather_tensors,
)
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
gpc.set_virtual_pipeline_parallel_rank(0)
self.num_model_chunks = num_model_chunks
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
if isinstance(engine.model, ShardedModelV2):
self.dtype = torch.half
elif isinstance(engine.model[0], NaiveAMPModel):
@@ -494,7 +504,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
model = model.model
sig = inspect.signature(model.forward)
for p in sig.parameters.values():
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported"
def load_batch(self, data_iter):
super().load_batch(data_iter)
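
The VAR_POSITIONAL assertion above is also easy to demonstrate standalone; a sketch with illustrative functions (not from the repo):

import inspect

def forward_ok(self, x, y=None): ...
def forward_bad(self, *args): ...

for fn in (forward_ok, forward_bad):
    has_var_pos = any(
        p.kind == inspect.Parameter.VAR_POSITIONAL
        for p in inspect.signature(fn).parameters.values()
    )
    print(fn.__name__, "rejected" if has_var_pos else "accepted")
# forward_ok accepted
# forward_bad rejected
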
@@ -506,13 +516,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return self._move_to_device(data)
def _forward_step(self,
engine,
model_chunk_id,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None):
def _forward_step(
self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None
):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -528,8 +534,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
micro_batch_data = self.load_micro_batch(model_chunk_id)
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
engine.model[model_chunk_id])
data, label = self._get_data_label_for_current_step(
input_obj, micro_batch_data, engine.criterion, engine.model[model_chunk_id]
)
output_obj = self._call_engine(engine.model[model_chunk_id], data)
@@ -546,7 +553,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
else:
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}"
)
return output_obj
@@ -566,8 +573,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
The loss would be returned only in the last stage.
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
self.load_batch(data_iter)
model = engine.model
input_objs = [[] for _ in range(len(model))]
@@ -605,19 +613,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
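
The chunk-id helper reproduced from the hunk above, with the closed-over sizes turned into parameters so it runs standalone; the demo assumes 2 stages and 2 model chunks:

def get_model_chunk_id(microbatch_id, forward, pipeline_parallel_size=2, num_model_chunks=2):
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(i, forward=True) for i in range(8)])
# [0, 0, 1, 1, 0, 0, 1, 1]
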
def _forward_step_helper(microbatch_id):
@@ -629,16 +635,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# forward step
if gpc.is_pipeline_first_stage():
if len(input_objs[model_chunk_id]) == \
len(output_objs[model_chunk_id]):
if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]):
input_objs[model_chunk_id].append(None)
input_obj = input_objs[model_chunk_id][-1]
output_obj = self._forward_step(engine,
model_chunk_id,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(
engine,
model_chunk_id,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
)
output_objs[model_chunk_id].append(output_obj)
# if forward-only, no need to save tensors for a backward pass
@@ -670,8 +677,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not gpc.is_pipeline_first_stage():
input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
input_objs[0].append(
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
)
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
@@ -683,8 +690,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
output_obj_shapes[model_chunk_id] = []
for out_tensor in output_obj:
output_obj_shapes[model_chunk_id].append(out_tensor.shape)
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
send_tensor_shape_flags[model_chunk_id])
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(
output_obj, send_tensor_shape_flags[model_chunk_id]
)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
@@ -701,34 +709,36 @@ class InterleavedPipelineSchedule(PipelineSchedule):
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
if not gpc.is_pipeline_first_stage():
input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
input_obj_shapes[next_forward_model_chunk_id])
input_obj_shapes[next_forward_model_chunk_id]
)
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches:
input_obj_grad = None
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev,
recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
else:
input_obj = \
comm.send_forward_recv_forward(
output_obj,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
input_objs[next_forward_model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
@@ -771,8 +781,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1),
forward=False)
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
@@ -787,14 +798,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate objs.
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev,
recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
# Put input_obj and output_obj_grad in data structures in the
# right location.
@@ -807,8 +820,10 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not forward_only:
if all_warmup_microbatches:
output_obj_grads[num_model_chunks - 1].append(
comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
comm.recv_backward(
output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors
)
)
for k in range(num_microbatches_remaining, num_microbatches):
input_obj_grad = _backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
@@ -820,11 +835,14 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
output_obj_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward(input_obj_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
comm.send_backward_recv_backward(
input_obj_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
)
if len(return_tensors) > 0:
output, label = pack_return_tensors(return_tensors)
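
At the end of the step, the accumulated (output, label) pairs are merged by `pack_return_tensors`. A minimal sketch of that concatenation for the plain-tensor case, with made-up shapes:

import torch

# three microbatches of (output, label), as collected in return_tensors
return_tensors = [(torch.ones(2, 4), torch.zeros(2, dtype=torch.long)) for _ in range(3)]
output, label = zip(*return_tensors)
output = torch.cat(output, dim=0)  # microbatches back into one batch
label = torch.cat(label, dim=0)
print(output.shape, label.shape)  # torch.Size([6, 4]) torch.Size([6])
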


@@ -21,7 +21,7 @@ def pack_return_tensors(return_tensors):
elif isinstance(output[0], (list, tuple)):
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
else:
raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
raise TypeError(f"Output of model must be tensor or list/tuple of tensors")
if isinstance(label[0], torch.Tensor):
label = torch.cat(label, dim=0)
else:
@@ -59,12 +59,9 @@ class PipelineScheduleV2(PipelineSchedule):
"""
def forward_backward_step(self,
engine: Engine,
data_iter: Iterable,
forward_only=False,
return_loss=True,
return_output_label=True) -> Tuple[torch.Tensor]:
def forward_backward_step(
self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True
) -> Tuple[torch.Tensor]:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if this is the last stage, an empty tuple otherwise.
@@ -80,14 +77,15 @@ class PipelineScheduleV2(PipelineSchedule):
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
self.load_batch(data_iter)
# num_warmup_microbatches is the number of steps during which not all pipeline stages are active yet
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE)
- gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = (
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
)
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
@@ -109,11 +107,9 @@ class PipelineScheduleV2(PipelineSchedule):
for i in range(num_warmup_microbatches):
input_obj = comm.recv_forward()
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
)
comm.send_forward(output_obj)
@@ -129,13 +125,11 @@ class PipelineScheduleV2(PipelineSchedule):
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
last_iteration = i == (num_microbatches_remaining - 1)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(
engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
)
if forward_only:
comm.send_forward(output_obj)