[zero] ZeRO supports pipeline parallel (#477)

This commit is contained in:
ver217 2022-03-21 16:55:37 +08:00 committed by GitHub
parent 7f5e4592eb
commit 8d3250d74b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 95 deletions

View File

@ -1,12 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
import torch.distributed as dist from collections import defaultdict
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER from colossalai.registry import GRADIENT_HANDLER
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._base_gradient_handler import BaseGradientHandler from ._base_gradient_handler import BaseGradientHandler
from collections import defaultdict
@GRADIENT_HANDLER.register_module @GRADIENT_HANDLER.register_module
@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
for group, group_buckets in buckets.items(): for group, group_buckets in buckets.items():
for tp, bucket in group_buckets.items(): for tp, bucket in group_buckets.items():
grads = [param.grad.data for param in bucket] grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads) coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced) buf.copy_(synced)

View File

@ -12,7 +12,8 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ShardedOptimizer, ShardedModel from colossalai.zero import ShardedModel, ShardedOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from ._base_schedule import BaseSchedule from ._base_schedule import BaseSchedule
@ -79,8 +80,8 @@ class PipelineSchedule(BaseSchedule):
def _get_data_slice(self, data, offset): def _get_data_slice(self, data, offset):
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data[offset: offset + self.microbatch_size] return data[offset:offset + self.microbatch_size]
else: elif isinstance(data, dict):
return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
def load_micro_batch(self): def load_micro_batch(self):
@ -92,11 +93,9 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine): def pre_processing(self, engine):
# TODO: remove this after testing new zero with pipeline parallelism # TODO: remove this after testing new zero with pipeline parallelism
if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel): if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
raise TypeError( raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
"Pipeline schedule is currently not compatible with ZeRO"
)
model = engine.model model = engine.model
if isinstance(model, NaiveAMPModel): if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
self.dtype = torch.half self.dtype = torch.half
model = model.model model = model.model
sig = inspect.signature(model.forward) sig = inspect.signature(model.forward)
@ -107,6 +106,8 @@ class PipelineSchedule(BaseSchedule):
def _call_engine(model, input_tensor, batch_data): def _call_engine(model, input_tensor, batch_data):
if isinstance(model, NaiveAMPModel): if isinstance(model, NaiveAMPModel):
sig = inspect.signature(model.model.forward) sig = inspect.signature(model.model.forward)
elif isinstance(model, ShardedModelV2):
sig = inspect.signature(model.module.forward)
else: else:
sig = inspect.signature(model.forward) sig = inspect.signature(model.forward)
if isinstance(batch_data, torch.Tensor): if isinstance(batch_data, torch.Tensor):
@ -162,9 +163,11 @@ class PipelineSchedule(BaseSchedule):
return output_tensor return output_tensor
else: else:
assert isinstance( assert isinstance(
output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug( self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}') f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor return output_tensor
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad): def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
@ -203,12 +206,7 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad return input_tensor_grad
def forward_backward_step(self, def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
engine,
data_iter,
forward_only=False,
return_loss=True,
return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise. Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -231,10 +229,9 @@ class PipelineSchedule(BaseSchedule):
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter) self.load_batch(data_iter)
num_warmup_microbatches = \ num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) - (gpc.get_world_size(ParallelMode.PIPELINE)
gpc.get_local_rank(ParallelMode.PIPELINE) - 1) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = min(num_warmup_microbatches, num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes # Input, output tensors only need to be saved when doing backward passes
@ -257,13 +254,14 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE): if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape) ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors) scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self.forward_step( output_tensor = self.forward_step(engine,
engine, input_tensor, return_tensors, input_tensor,
return_output_label=return_output_label, return_tensors,
accum_loss=accum_loss return_output_label=return_output_label,
) accum_loss=accum_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE): if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape bt_shape = output_tensor.shape
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker) fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
@ -279,28 +277,32 @@ class PipelineSchedule(BaseSchedule):
if num_microbatches_remaining > 0: if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE): if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape) ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors) scatter_gather_tensors=self.scatter_gather_tensors)
# Run 1F1B in steady state. # Run 1F1B in steady state.
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step( output_tensor = self.forward_step(engine,
engine, input_tensor, return_tensors, input_tensor,
return_output_label=return_output_label, return_tensors,
accum_loss=accum_loss return_output_label=return_output_label,
) accum_loss=accum_loss)
if forward_only: if forward_only:
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors) comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if not last_iteration: if not last_iteration:
input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors) scatter_gather_tensors=self.scatter_gather_tensors)
else: else:
output_tensor_grad = comm.send_forward_recv_backward( output_tensor_grad = comm.send_forward_recv_backward(output_tensor,
output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
@ -311,18 +313,16 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step( input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
engine,
input_tensor, output_tensor,
output_tensor_grad
)
if last_iteration: if last_iteration:
input_tensor = None input_tensor = None
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else: else:
input_tensor = comm.send_backward_recv_forward( input_tensor = comm.send_backward_recv_forward(input_tensor_grad,
input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not forward_only:
@ -330,14 +330,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype, output_tensor_grad = comm.recv_backward(bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors) scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self.backward_step( input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
engine,
input_tensor, output_tensor,
output_tensor_grad
)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
@ -349,6 +346,7 @@ class PipelineSchedule(BaseSchedule):
class InterleavedPipelineSchedule(PipelineSchedule): class InterleavedPipelineSchedule(PipelineSchedule):
def __init__(self, def __init__(self,
num_microbatches, num_microbatches,
num_model_chunks, num_model_chunks,
@ -372,21 +370,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
""" """
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size' 'num_microbatches must be an integer multiple of pipeline parallel world size'
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, super().__init__(num_microbatches,
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors) batch_data_process_func=batch_data_process_func,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather_tensors)
gpc.set_virtual_pipeline_parallel_size(num_model_chunks) gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
gpc.set_virtual_pipeline_parallel_rank(0) gpc.set_virtual_pipeline_parallel_rank(0)
self.num_model_chunks = num_model_chunks self.num_model_chunks = num_model_chunks
def pre_processing(self, engine): def pre_processing(self, engine):
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): if isinstance(engine.model, ShardedModelV2):
raise TypeError( self.dtype = torch.half
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" elif isinstance(engine.model[0], NaiveAMPModel):
)
if isinstance(engine.model[0], NaiveAMPModel):
self.dtype = torch.half self.dtype = torch.half
for model in engine.model: for model in engine.model:
if isinstance(model, NaiveAMPModel): if isinstance(model, NaiveAMPModel):
model = model.model model = model.model
@ -405,7 +401,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.microbatch_offset[model_chunk_id] += self.microbatch_size self.microbatch_offset[model_chunk_id] += self.microbatch_size
return self._move_to_device(data), self._move_to_device(label) return self._move_to_device(data), self._move_to_device(label)
def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None): def forward_step(self,
engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used. is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users. Returns output tensor. This is a helper function and can be ignored by users.
@ -425,9 +427,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
return output_tensor return output_tensor
else: else:
assert isinstance( assert isinstance(
output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug( self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}') f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor return output_tensor
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
@ -488,10 +492,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
else: else:
num_warmup_microbatches = \ num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += ( num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_model_chunks - 1) * pipeline_parallel_size num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \ num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches num_microbatches - num_warmup_microbatches
@ -516,8 +518,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
len(output_tensors[model_chunk_id]): len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None) input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1] input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = self.forward_step(engine, model_chunk_id, input_tensor, output_tensor = self.forward_step(engine,
return_tensors, return_output_label=return_output_label, accum_loss=accum_loss) model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensors[model_chunk_id].append(output_tensor) output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass # if forward-only, no need to save tensors for a backward pass
@ -548,18 +554,20 @@ class InterleavedPipelineSchedule(PipelineSchedule):
gpc.set_virtual_pipeline_parallel_rank(0) gpc.set_virtual_pipeline_parallel_rank(0)
if not gpc.is_pipeline_first_stage(): if not gpc.is_pipeline_first_stage():
input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0]) input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype, input_tensors[0].append(
scatter_gather_tensors=self.scatter_gather_tensors)) comm.recv_forward(input_tensor_shapes[0],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_warmup_microbatches): for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True) model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = forward_step_helper(k) output_tensor = forward_step_helper(k)
if not gpc.is_pipeline_last_stage(): if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta( send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
output_tensor, send_tensor_shape_flags[model_chunk_id]) send_tensor_shape_flags[model_chunk_id])
# Determine if tensor should be received from previous stage. # Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True) next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True recv_prev = True
if gpc.is_pipeline_first_stage(ignore_virtual=True): if gpc.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0: if next_forward_model_chunk_id == 0:
@ -584,7 +592,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = True recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True): if gpc.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False recv_next = False
output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
input_tensor, output_tensor_grad = \ input_tensor, output_tensor_grad = \
comm.send_forward_backward_recv_forward_backward( comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
@ -593,7 +601,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype, dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors) scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else: else:
input_tensor = \ input_tensor = \
comm.send_forward_recv_forward( comm.send_forward_recv_forward(
@ -634,26 +642,23 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_prev = True recv_prev = True
if gpc.is_pipeline_first_stage(ignore_virtual=True): if gpc.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1). # First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id( next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
forward_k - (pipeline_parallel_size - 1), forward=True)
if next_forward_model_chunk_id == (num_model_chunks - 1): if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False recv_prev = False
next_forward_model_chunk_id += 1 next_forward_model_chunk_id += 1
else: else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
forward=True)
recv_next = True recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True): if gpc.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1). # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id( next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1),
backward_k - (pipeline_parallel_size - 1), forward=False) forward=False)
if next_backward_model_chunk_id == 0: if next_backward_model_chunk_id == 0:
recv_next = False recv_next = False
next_backward_model_chunk_id -= 1 next_backward_model_chunk_id -= 1
else: else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
forward=False)
# If last iteration, don't receive; we already received one extra # If last iteration, don't receive; we already received one extra
# before the start of the for loop. # before the start of the for loop.
@ -677,17 +682,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if recv_prev: if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next: if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
output_tensor_grad)
# Run cooldown backward passes (flush out pipeline). # Run cooldown backward passes (flush out pipeline).
if not forward_only: if not forward_only:
if all_warmup_microbatches: if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append( output_tensor_grads[num_model_chunks - 1].append(
comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors)) comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches): for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k) input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True): if gpc.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1): if next_backward_model_chunk_id == (num_model_chunks - 1):
@ -696,12 +701,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False recv_next = False
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward( comm.send_backward_recv_backward(input_tensor_grad,
input_tensor_grad, output_shape,
output_shape, recv_next=recv_next,
recv_next=recv_next, dtype=self.dtype,
dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors))
scatter_gather_tensors=self.scatter_gather_tensors))
if len(return_tensors) > 0: if len(return_tensors) > 0:
output, label = pack_return_tensors(return_tensors) output, label = pack_return_tensors(return_tensors)

View File

@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
raise NotImplementedError raise NotImplementedError
def __getitem__(self, idx: int):
assert isinstance(self.module, nn.ModuleList)
return self.module[idx]
def __len__(self):
assert isinstance(self.module, nn.ModuleList)
return len(self.module)
def __iter__(self):
assert isinstance(self.module, nn.ModuleList)
return iter(self.module)