[doc] improved docstring and assertion messages for the engine module (#871)

This commit is contained in:
Frank Lee
2022-04-26 10:00:18 +08:00
committed by GitHub
parent 1c34382678
commit 11f54c7b6b
9 changed files with 180 additions and 60 deletions

View File

@@ -81,6 +81,9 @@ class PipelineSchedule(BaseSchedule):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
super().__init__(batch_data_process_func=batch_data_process_func)
assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
self.num_microbatches = num_microbatches
self.dtype = torch.float
self.tensor_shape = tensor_shape
@@ -150,7 +153,7 @@ class PipelineSchedule(BaseSchedule):
else:
return model(input_tensor, **batch_data)
def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -186,7 +189,7 @@ class PipelineSchedule(BaseSchedule):
)
return output_tensor
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
@@ -267,11 +270,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self.forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
@@ -295,11 +298,11 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if forward_only:
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
@@ -323,7 +326,7 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
if last_iteration:
input_tensor = None
@@ -344,7 +347,7 @@ class PipelineSchedule(BaseSchedule):
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
@@ -358,8 +361,8 @@ class PipelineSchedule(BaseSchedule):
class InterleavedPipelineSchedule(PipelineSchedule):
def __init__(self,
num_microbatches,
num_model_chunks,
num_microbatches: int,
num_model_chunks: int,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
@@ -378,6 +381,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
super().__init__(num_microbatches,
batch_data_process_func=batch_data_process_func,
tensor_shape=tensor_shape,
@@ -409,13 +414,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return self._move_to_device(data), self._move_to_device(label)
def forward_step(self,
engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=True,
accum_loss=None):
def _forward_step(self,
engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -522,7 +527,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
def forward_step_helper(microbatch_id):
def _forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
@@ -535,12 +540,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = self.forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensor = self._forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
@@ -550,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
return output_tensor
def backward_step_helper(microbatch_id):
def _backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
@@ -563,7 +568,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
return input_tensor_grad
@@ -578,7 +583,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = forward_step_helper(k)
output_tensor = _forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
@@ -633,11 +638,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = forward_step_helper(forward_k)
output_tensor = _forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
input_tensor_grad = _backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
@@ -708,7 +713,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
input_tensor_grad = _backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):