Mirror of https://github.com/hpcaitech/ColossalAI.git
add scatter/gather optim for pipeline (#123)
@@ -6,7 +6,7 @@ import inspect
 import torch.cuda
 from torch import Tensor
 
-from colossalai.communication import *
+import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -33,16 +33,22 @@ class PipelineSchedule(BaseSchedule):
     :type num_microbatches: int
     :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
     :type batch_data_process_func: Callable
+    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+    :type scatter_gather_tensors: bool
     """
 
     def __init__(self,
                  num_microbatches,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
+        self.scatter_gather_tensors = False
+        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
+            self.scatter_gather_tensors = scatter_gather_tensors
 
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
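The guard above means the flag defaults to off and only takes effect when the 1D tensor-parallel group is initialized with more than one rank. A minimal usage sketch, with the import path and surrounding setup assumed rather than taken from this commit:

# Illustrative only: enabling the new option on a pipeline schedule.
# It silently stays disabled unless ParallelMode.PARALLEL_1D has world size > 1.
from colossalai.engine.schedule import PipelineSchedule  # import path assumed

schedule = PipelineSchedule(
    num_microbatches=4,
    scatter_gather_tensors=True,
)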
@@ -227,8 +233,9 @@ class PipelineSchedule(BaseSchedule):
         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
             output_tensor = self.forward_step(
                 engine, input_tensor, return_tensors,
                 return_output_label=return_output_label,
@@ -236,8 +243,8 @@ class PipelineSchedule(BaseSchedule):
             )
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
-                fs_checker = send_tensor_meta(output_tensor, fs_checker)
-                send_forward(output_tensor)
+                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
             if not forward_only:
                 input_tensors.append(input_tensor)
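Every point-to-point call in the schedule now threads the flag through to colossalai.communication. A rough sketch of the idea behind it, using hypothetical helper names rather than the library's actual API: under 1D tensor parallelism the activation sent between stages is replicated across the tensor-parallel ranks, so each rank can send just its slice over the pipeline link and the receiving stage rebuilds the full tensor with an all-gather, cutting inter-stage traffic by roughly the tensor-parallel world size.

# Hypothetical sketch of the scatter/gather trick -- not ColossalAI's implementation.
# Assumes torch.distributed is initialized and `group` is the 1D tensor-parallel group.
import torch
import torch.distributed as dist

def split_for_send(tensor: torch.Tensor, group) -> torch.Tensor:
    """Keep only this rank's contiguous chunk of the flattened tensor before the P2P send."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    flat = tensor.contiguous().view(-1)
    chunk = flat.numel() // world_size  # assumes numel is divisible by world_size
    return flat[rank * chunk:(rank + 1) * chunk]

def gather_after_recv(chunk: torch.Tensor, shape, group) -> torch.Tensor:
    """All-gather the chunks across the tensor-parallel group and restore the original shape."""
    world_size = dist.get_world_size(group)
    pieces = [torch.empty_like(chunk) for _ in range(world_size)]
    dist.all_gather(pieces, chunk, group=group)
    return torch.cat(pieces).view(shape)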
@@ -248,8 +255,9 @@ class PipelineSchedule(BaseSchedule):
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = recv_tensor_meta(ft_shape)
-            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                ft_shape = comm.recv_tensor_meta(ft_shape)
+            input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                             scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
@@ -261,14 +269,15 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss=accum_loss
             )
             if forward_only:
-                send_forward(output_tensor)
+                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 if not last_iteration:
-                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)
+                    input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
+                                                     scatter_gather_tensors=self.scatter_gather_tensors)
 
             else:
-                output_tensor_grad = send_forward_recv_backward(
-                    output_tensor, bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.send_forward_recv_backward(
+                    output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -287,10 +296,10 @@ class PipelineSchedule(BaseSchedule):
 
                 if last_iteration:
                     input_tensor = None
-                    send_backward(input_tensor_grad)
+                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = send_backward_recv_forward(
-                        input_tensor_grad, ft_shape, dtype=self.dtype)
+                    input_tensor = comm.send_backward_recv_forward(
+                        input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run cooldown backward passes.
         if not forward_only:
@@ -298,7 +307,8 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)
 
-                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)
+                output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
+                                                        scatter_gather_tensors=self.scatter_gather_tensors)
 
                 input_tensor_grad = self.backward_step(
                     engine,
@@ -306,7 +316,7 @@ class PipelineSchedule(BaseSchedule):
                     output_tensor_grad
                 )
 
-                send_backward(input_tensor_grad)
+                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))
@@ -322,7 +332,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                  num_microbatches,
                  num_model_chunks,
                  batch_data_process_func: Callable = None,
-                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
+                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
+                 scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
         It uses interleaved 1F1B strategy. Other properties are similar as
         :class:`NonPipelineSchedule`.
@@ -333,10 +344,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         :type num_model_chunks: int
         :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
         :type batch_data_process_func: Callable
+        :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
+        :type scatter_gather_tensors: bool
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
+        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks
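The assertion above requires num_microbatches to be a multiple of the pipeline-parallel world size; with 4 pipeline stages, for example, it accepts 4, 8, 12, and so on. A minimal usage sketch, with the import path and concrete values assumed for illustration:

# Illustrative only: interleaved schedule with the new flag, under 4 pipeline stages.
from colossalai.engine.schedule import InterleavedPipelineSchedule  # import path assumed

schedule = InterleavedPipelineSchedule(
    num_microbatches=8,       # 8 % 4 == 0, so the assertion passes
    num_model_chunks=2,       # each stage holds two model chunks (virtual pipeline)
    scatter_gather_tensors=True,
)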
@@ -494,15 +508,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))
+            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
+        input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors))
 
         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
+                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor, send_tensor_shape_flags[model_chunk_id])
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -519,7 +534,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
 
             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
+                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
                         input_tensor_shapes[next_forward_model_chunk_id])
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
@@ -532,20 +547,22 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
                 input_tensor, output_tensor_grad = \
-                    send_forward_backward_recv_forward_backward(
+                    comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
                         input_shape,
                         output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
             else:
                 input_tensor = \
-                    send_forward_recv_forward(
+                    comm.send_forward_recv_forward(
                         output_tensor,
                         input_shape,
                         recv_prev=recv_prev,
-                        dtype=self.dtype)
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors)
             input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
         # Run 1F1B in steady state.
@@ -608,12 +625,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
             # Communicate tensors.
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                comm.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     input_shape,
                     output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
-                    dtype=self.dtype)
+                    dtype=self.dtype,
+                    scatter_gather_tensors=self.scatter_gather_tensors)
 
             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
@@ -627,7 +645,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks-1].append(
-                    recv_backward(output_tensor_shapes[num_model_chunks-1]))
+                    comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -639,11 +657,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
-                    send_backward_recv_backward(
+                    comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
-                        dtype=self.dtype))
+                        dtype=self.dtype,
+                        scatter_gather_tensors=self.scatter_gather_tensors))
 
         if len(return_tensors) > 0:
             output, label = tuple(map(list, zip(*return_tensors)))