add scatter/gather optim for pipeline (#123)

Author: ver217
Date:   2022-01-07 13:22:22 +08:00
Committed by: GitHub
Parent: 404e6f88ed
Commit: 293fb40c42
5 changed files with 166 additions and 56 deletions


@@ -6,7 +6,7 @@ import inspect
import torch.cuda
from torch import Tensor
from colossalai.communication import *
import colossalai.communication as comm
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
:type num_microbatches: int
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
:type batch_data_process_func: Callable
:param scatter_gather_tensors: If set to `True`, communication over the pipeline is reduced when 1D tensor parallelism is used
:type scatter_gather_tensors: bool
"""
def __init__(self,
num_microbatches,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
super().__init__(batch_data_process_func=batch_data_process_func)
self.num_microbatches = num_microbatches
self.dtype = torch.float
self.tensor_shape = tensor_shape
self.scatter_gather_tensors = False
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
self.scatter_gather_tensors = scatter_gather_tensors
def load_batch(self, data_iter):
# Pipeline schedule just puts data in memory
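
Usage note (not part of the diff): the new `scatter_gather_tensors` flag is passed at construction time and only takes effect when 1D tensor parallelism is initialized with a world size greater than 1; otherwise it stays `False`. A minimal sketch, assuming the schedule class is importable from `colossalai.engine.schedule` as in this version of the repo and that the engine/trainer setup is done elsewhere:

    from colossalai.engine.schedule import PipelineSchedule

    # Enable the scatter/gather optimization for inter-stage communication.
    # The flag is honored only if ParallelMode.PARALLEL_1D is initialized
    # with world size > 1 (see __init__ above); otherwise it is ignored.
    schedule = PipelineSchedule(num_microbatches=4,
                                scatter_gather_tensors=True)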
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self.forward_step(
engine, input_tensor, return_tensors,
return_output_label=return_output_label,
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
send_forward(output_tensor)
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if not forward_only:
input_tensors.append(input_tensor)
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
accum_loss=accum_loss
)
if forward_only:
send_forward(output_tensor)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if not last_iteration:
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, bt_shape, dtype=self.dtype)
output_tensor_grad = comm.send_forward_recv_backward(
output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, ft_shape, dtype=self.dtype)
input_tensor = comm.send_backward_recv_forward(
input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
# Run cooldown backward passes.
if not forward_only:
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self.backward_step(
engine,
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad
)
send_backward(input_tensor_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
if len(return_tensors) > 0:
output, label = tuple(map(list, zip(*return_tensors)))
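
For context: when 1D tensor parallelism is active, every rank in the tensor-parallel group would otherwise send and receive the same full activation over the pipeline. With `scatter_gather_tensors` enabled, the tensor is split into equal chunks so each rank only sends its own slice, and the receiving side all-gathers the slices to rebuild the full tensor. The sketch below illustrates the split/gather halves with plain `torch.distributed`; it is not the actual `colossalai.communication` implementation, and the function names are made up for illustration.

    import torch
    import torch.distributed as dist

    def split_for_send(tensor: torch.Tensor, group) -> torch.Tensor:
        """Keep only this rank's contiguous slice of the flattened tensor."""
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        flat = tensor.contiguous().view(-1)
        chunk = flat.numel() // world_size  # assumes numel is divisible by world_size
        return flat[rank * chunk:(rank + 1) * chunk].clone()

    def gather_after_recv(chunk: torch.Tensor, group) -> torch.Tensor:
        """Rebuild the full flattened tensor from the per-rank slices."""
        world_size = dist.get_world_size(group)
        parts = [torch.empty_like(chunk) for _ in range(world_size)]
        dist.all_gather(parts, chunk, group=group)
        return torch.cat(parts)

The caller would reshape the gathered result back to the original tensor shape; the point-to-point send/recv of the chunk itself is unchanged.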
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
num_microbatches,
num_model_chunks,
batch_data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
"""A helper schedule class for pipeline parallelism running environment.
It uses the interleaved 1F1B strategy. Other properties are similar to those of
:class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
:type num_model_chunks: int
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
:type batch_data_process_func: Callable
:param scatter_gather_tensors: If set to `True`, communication over the pipeline is reduced when 1D tensor parallelism is used
:type scatter_gather_tensors: bool
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
gpc.set_virtual_pipeline_parallel_rank(0)
self.num_model_chunks = num_model_chunks
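
Usage-wise, the interleaved schedule forwards the flag to the base class unchanged, so enabling it looks the same; the only extra constraint (asserted above) is that `num_microbatches` is a multiple of the pipeline parallel world size. A hypothetical example with illustrative values, under the same import-path assumption as above:

    from colossalai.engine.schedule import InterleavedPipelineSchedule

    # num_microbatches must be divisible by the pipeline parallel world size.
    schedule = InterleavedPipelineSchedule(num_microbatches=8,
                                           num_model_chunks=2,
                                           scatter_gather_tensors=True)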
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# Run warmup forward passes.
gpc.set_virtual_pipeline_parallel_rank(0)
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
output_tensor, send_tensor_shape_flags[model_chunk_id])
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id])
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
send_forward_recv_forward(
comm.send_forward_recv_forward(
output_tensor,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate tensors.
input_tensor, output_tensor_grad = \
send_forward_backward_recv_forward_backward(
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype)
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
recv_backward(output_tensor_shapes[num_model_chunks-1]))
comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
output_tensor_grads[next_backward_model_chunk_id].append(
send_backward_recv_backward(
comm.send_backward_recv_backward(
input_tensor_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype))
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
if len(return_tensors) > 0:
output, label = tuple(map(list, zip(*return_tensors)))