diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 458a11509..8b6816f91 100644 --- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -1,12 +1,14 @@ #!/usr/bin/env python -import torch.distributed as dist -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from collections import defaultdict +import torch +import torch.distributed as dist from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + from ._base_gradient_handler import BaseGradientHandler -from collections import defaultdict @GRADIENT_HANDLER.register_module @@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler): for group, group_buckets in buckets.items(): for tp, bucket in group_buckets.items(): grads = [param.grad.data for param in bucket] - coalesced = _flatten_dense_tensors(grads) + coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 24256ccd0..a65ec3275 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -12,7 +12,8 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils import switch_virtual_pipeline_parallel_rank from colossalai.utils.cuda import get_current_device -from colossalai.zero import ShardedOptimizer, ShardedModel +from colossalai.zero import ShardedModel, ShardedOptimizer +from colossalai.zero.sharded_model import ShardedModelV2 from ._base_schedule import BaseSchedule @@ -79,8 +80,8 @@ class PipelineSchedule(BaseSchedule): def _get_data_slice(self, data, offset): if isinstance(data, torch.Tensor): - return data[offset: offset + self.microbatch_size] - else: + return data[offset:offset + self.microbatch_size] + elif isinstance(data, dict): return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} def load_micro_batch(self): @@ -92,11 +93,9 @@ class PipelineSchedule(BaseSchedule): def pre_processing(self, engine): # TODO: remove this after testing new zero with pipeline parallelism if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel): - raise TypeError( - "Pipeline schedule is currently not compatible with ZeRO" - ) + raise TypeError("Pipeline schedule is currently not compatible with ZeRO") model = engine.model - if isinstance(model, NaiveAMPModel): + if isinstance(model, (NaiveAMPModel, ShardedModelV2)): self.dtype = torch.half model = model.model sig = inspect.signature(model.forward) @@ -107,6 +106,8 @@ class PipelineSchedule(BaseSchedule): def _call_engine(model, input_tensor, batch_data): if isinstance(model, NaiveAMPModel): sig = inspect.signature(model.model.forward) + elif isinstance(model, ShardedModelV2): + sig = inspect.signature(model.module.forward) else: sig = inspect.signature(model.forward) if isinstance(batch_data, torch.Tensor): @@ -162,9 +163,11 @@ class PipelineSchedule(BaseSchedule): return output_tensor else: assert 
isinstance( - output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' + output_tensor, + torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}') + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}' + ) return output_tensor def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad): @@ -203,12 +206,7 @@ class PipelineSchedule(BaseSchedule): return input_tensor_grad - def forward_backward_step(self, - engine, - data_iter, - forward_only=False, - return_loss=True, - return_output_label=True): + def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -231,10 +229,9 @@ class PipelineSchedule(BaseSchedule): 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' self.load_batch(data_iter) num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) - num_warmup_microbatches = min(num_warmup_microbatches, - self.num_microbatches) + (gpc.get_world_size(ParallelMode.PIPELINE) + - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches # Input, output tensors only need to be saved when doing backward passes @@ -257,13 +254,14 @@ class PipelineSchedule(BaseSchedule): for i in range(num_warmup_microbatches): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = comm.recv_tensor_meta(ft_shape) - input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, + input_tensor = comm.recv_forward(ft_shape, + dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - output_tensor = self.forward_step( - engine, input_tensor, return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss - ) + output_tensor = self.forward_step(engine, + input_tensor, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) if not gpc.is_last_rank(ParallelMode.PIPELINE): bt_shape = output_tensor.shape fs_checker = comm.send_tensor_meta(output_tensor, fs_checker) @@ -279,28 +277,32 @@ class PipelineSchedule(BaseSchedule): if num_microbatches_remaining > 0: if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shape = comm.recv_tensor_meta(ft_shape) - input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, + input_tensor = comm.recv_forward(ft_shape, + dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) # Run 1F1B in steady state. 
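Note on the gradient handler change in the first file of this diff (_pipeline_parallel_gradient_handler.py): shared-module gradients are grouped into buckets, flattened into one contiguous buffer, all-reduced once per bucket, and scattered back; the patch additionally moves the coalesced buffer onto the current CUDA device before the collective. A minimal standalone sketch of that flatten -> all_reduce -> unflatten pattern (the helper name is illustrative, not part of the patch; it assumes torch.distributed is already initialized and CUDA is available):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_grads_coalesced(params, group=None):
    # Collect dense grads, coalesce them into a single buffer on the current
    # CUDA device (as the patch now does), reduce once, then copy the synced
    # values back into the original gradient storages.
    grads = [p.grad.data for p in params if p.grad is not None]
    if not grads:
        return
    coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)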
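The warmup loop above and the steady-state/cooldown code that follows use the standard non-interleaved 1F1B sizing: each stage runs (pipeline world size - rank - 1) forward-only microbatches, capped at the total number of microbatches, then alternates one forward with one backward, and finally drains the same number of backwards it warmed up with. A sketch with plain ints in place of gpc (the function name is illustrative only):

def one_f_one_b_phases(num_microbatches, pp_world_size, pp_rank):
    # Earlier stages warm up with more forwards so the last stage can start
    # its first backward as soon as possible.
    num_warmup = min(pp_world_size - pp_rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup   # 1F1B steady state
    num_cooldown = num_warmup                    # backwards left to drain
    return num_warmup, num_steady, num_cooldown

# e.g. 4 stages, 8 microbatches: rank 0 -> (3, 5, 3), rank 3 -> (0, 8, 0)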
for i in range(num_microbatches_remaining): last_iteration = (i == (num_microbatches_remaining - 1)) - output_tensor = self.forward_step( - engine, input_tensor, return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss - ) + output_tensor = self.forward_step(engine, + input_tensor, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) if forward_only: comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors) if not last_iteration: - input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype, + input_tensor = comm.recv_forward(ft_shape, + dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) else: - output_tensor_grad = comm.send_forward_recv_backward( - output_tensor, bt_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) + output_tensor_grad = comm.send_forward_recv_backward(output_tensor, + bt_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) # Add input_tensor and output_tensor to end of list. input_tensors.append(input_tensor) @@ -311,18 +313,16 @@ class PipelineSchedule(BaseSchedule): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - input_tensor_grad = self.backward_step( - engine, - input_tensor, output_tensor, - output_tensor_grad - ) + input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad) if last_iteration: input_tensor = None comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) else: - input_tensor = comm.send_backward_recv_forward( - input_tensor_grad, ft_shape, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) + input_tensor = comm.send_backward_recv_forward(input_tensor_grad, + ft_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) # Run cooldown backward passes. 
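The steady-state loop above keeps a FIFO of in-flight microbatches: every new forward output is appended to input_tensors/output_tensors, and each backward pops the oldest pair. A toy, single-process illustration of that ordering (no communication; the helper name is hypothetical):

from collections import deque

def steady_state_order(num_remaining, warmup_queue_len):
    # Microbatches 0..warmup_queue_len-1 already ran forward during warmup
    # and are waiting for their backward pass.
    queue = deque(range(warmup_queue_len))
    order = []
    for i in range(num_remaining):
        fwd_id = warmup_queue_len + i
        order.append(('F', fwd_id))           # forward on the newest microbatch
        queue.append(fwd_id)
        order.append(('B', queue.popleft()))  # backward on the oldest one
    return order

# steady_state_order(5, 3) -> F3 B0 F4 B1 F5 B2 F6 B3 F7 B4, i.e. classic 1F1B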
if not forward_only: @@ -330,14 +330,11 @@ class PipelineSchedule(BaseSchedule): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype, + output_tensor_grad = comm.recv_backward(bt_shape, + dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - input_tensor_grad = self.backward_step( - engine, - input_tensor, output_tensor, - output_tensor_grad - ) + input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad) comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) @@ -349,6 +346,7 @@ class PipelineSchedule(BaseSchedule): class InterleavedPipelineSchedule(PipelineSchedule): + def __init__(self, num_microbatches, num_model_chunks, @@ -372,21 +370,19 @@ class InterleavedPipelineSchedule(PipelineSchedule): """ assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ 'num_microbatches must be an integer multiple of pipeline parallel world size' - super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, - tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors) + super().__init__(num_microbatches, + batch_data_process_func=batch_data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors) gpc.set_virtual_pipeline_parallel_size(num_model_chunks) gpc.set_virtual_pipeline_parallel_rank(0) self.num_model_chunks = num_model_chunks def pre_processing(self, engine): - if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)): - raise TypeError( - "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" - ) - - if isinstance(engine.model[0], NaiveAMPModel): + if isinstance(engine.model, ShardedModelV2): + self.dtype = torch.half + elif isinstance(engine.model[0], NaiveAMPModel): self.dtype = torch.half - for model in engine.model: if isinstance(model, NaiveAMPModel): model = model.model @@ -405,7 +401,13 @@ class InterleavedPipelineSchedule(PipelineSchedule): self.microbatch_offset[model_chunk_id] += self.microbatch_size return self._move_to_device(data), self._move_to_device(label) - def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None): + def forward_step(self, + engine, + model_chunk_id, + input_tensor, + return_tensors, + return_output_label=True, + accum_loss=None): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_tensor is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -425,9 +427,11 @@ class InterleavedPipelineSchedule(PipelineSchedule): return output_tensor else: assert isinstance( - output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' + output_tensor, + torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' 
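load_micro_batch in the interleaved schedule above keeps one offset per model chunk and slices the batch with the same _get_data_slice logic patched earlier in this diff (tensor batches are sliced along dim 0, dict batches value by value). A simplified sketch of that bookkeeping (MicroBatchCursor is an illustrative stand-in, not a ColossalAI class):

import torch

def get_data_slice(data, offset, microbatch_size):
    # Mirrors _get_data_slice from the patch above.
    if isinstance(data, torch.Tensor):
        return data[offset:offset + microbatch_size]
    elif isinstance(data, dict):
        return {k: v[offset:offset + microbatch_size] for k, v in data.items()}

class MicroBatchCursor:
    def __init__(self, num_model_chunks, microbatch_size):
        self.offsets = [0] * num_model_chunks
        self.microbatch_size = microbatch_size

    def next(self, data, model_chunk_id):
        # Each model chunk advances through the batch independently.
        offset = self.offsets[model_chunk_id]
        self.offsets[model_chunk_id] += self.microbatch_size
        return get_data_slice(data, offset, self.microbatch_size)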
self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}') + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}' + ) return output_tensor def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): @@ -488,10 +492,8 @@ class InterleavedPipelineSchedule(PipelineSchedule): else: num_warmup_microbatches = \ (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += ( - num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, - num_microbatches) + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = \ num_microbatches - num_warmup_microbatches @@ -516,8 +518,12 @@ class InterleavedPipelineSchedule(PipelineSchedule): len(output_tensors[model_chunk_id]): input_tensors[model_chunk_id].append(None) input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = self.forward_step(engine, model_chunk_id, input_tensor, - return_tensors, return_output_label=return_output_label, accum_loss=accum_loss) + output_tensor = self.forward_step(engine, + model_chunk_id, + input_tensor, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) output_tensors[model_chunk_id].append(output_tensor) # if forward-only, no need to save tensors for a backward pass @@ -548,18 +554,20 @@ class InterleavedPipelineSchedule(PipelineSchedule): gpc.set_virtual_pipeline_parallel_rank(0) if not gpc.is_pipeline_first_stage(): input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0]) - input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + input_tensors[0].append( + comm.recv_forward(input_tensor_shapes[0], + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) for k in range(num_warmup_microbatches): model_chunk_id = get_model_chunk_id(k, forward=True) output_tensor = forward_step_helper(k) if not gpc.is_pipeline_last_stage(): output_tensor_shapes[model_chunk_id] = output_tensor.shape - send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta( - output_tensor, send_tensor_shape_flags[model_chunk_id]) + send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor, + send_tensor_shape_flags[model_chunk_id]) # Determine if tensor should be received from previous stage. 
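The warmup sizing for the interleaved schedule (a few hunks above) is larger than in the plain 1F1B case: twice the per-stage lead of the non-interleaved schedule, plus a full round of pipeline_parallel_size forwards for every extra model chunk a stage hosts. A sketch of that general branch with plain ints (the surrounding code also special-cases forward-only runs, which this sketch omits):

def interleaved_warmup_microbatches(num_microbatches, pipeline_parallel_size,
                                    pipeline_parallel_rank, num_model_chunks):
    num_warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    num_warmup += (num_model_chunks - 1) * pipeline_parallel_size
    return min(num_warmup, num_microbatches)

# e.g. 4 stages, 2 chunks per stage, 16 microbatches:
#   rank 0 -> min(6 + 4, 16) = 10 warmup forwards before its first backward
#   rank 3 -> min(0 + 4, 16) = 4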
- next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True) + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) recv_prev = True if gpc.is_pipeline_first_stage(ignore_virtual=True): if next_forward_model_chunk_id == 0: @@ -584,7 +592,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): recv_next = False - output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None + output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None input_tensor, output_tensor_grad = \ comm.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, @@ -593,7 +601,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): recv_prev=recv_prev, recv_next=recv_next, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) else: input_tensor = \ comm.send_forward_recv_forward( @@ -634,26 +642,23 @@ class InterleavedPipelineSchedule(PipelineSchedule): recv_prev = True if gpc.is_pipeline_first_stage(ignore_virtual=True): # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True) + next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True) if next_forward_model_chunk_id == (num_model_chunks - 1): recv_prev = False next_forward_model_chunk_id += 1 else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, - forward=True) + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False) + next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), + forward=False) if next_backward_model_chunk_id == 0: recv_next = False next_backward_model_chunk_id -= 1 else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, - forward=False) + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) # If last iteration, don't receive; we already received one extra # before the start of the for loop. @@ -677,17 +682,17 @@ class InterleavedPipelineSchedule(PipelineSchedule): if recv_prev: input_tensors[next_forward_model_chunk_id].append(input_tensor) if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append( - output_tensor_grad) + output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) # Run cooldown backward passes (flush out pipeline). 
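get_model_chunk_id, called throughout the loops above, is defined outside the hunks shown in this diff. A Megatron-style mapping that is consistent with how it is called here (an assumption for illustration, not the verbatim ColossalAI helper):

def get_model_chunk_id(microbatch_id, pipeline_parallel_size, num_model_chunks, forward):
    # Microbatches are scheduled in groups of pipeline_parallel_size * num_model_chunks;
    # within a group, each consecutive block of pipeline_parallel_size microbatches
    # targets the next model chunk on this stage.
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        # Backward passes visit the chunks in reverse order.
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id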
if not forward_only: if all_warmup_microbatches: - output_tensor_grads[num_model_chunks-1].append( - comm.recv_backward(output_tensor_shapes[num_model_chunks-1], scatter_gather_tensors=self.scatter_gather_tensors)) + output_tensor_grads[num_model_chunks - 1].append( + comm.recv_backward(output_tensor_shapes[num_model_chunks - 1], + scatter_gather_tensors=self.scatter_gather_tensors)) for k in range(num_microbatches_remaining, num_microbatches): input_tensor_grad = backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): if next_backward_model_chunk_id == (num_model_chunks - 1): @@ -696,12 +701,11 @@ class InterleavedPipelineSchedule(PipelineSchedule): recv_next = False output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None output_tensor_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward( - input_tensor_grad, - output_shape, - recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.send_backward_recv_backward(input_tensor_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors)) if len(return_tensors) > 0: output, label = pack_return_tensors(return_tensors) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 4efd096b1..d8e309d66 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module): def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): raise NotImplementedError + + def __getitem__(self, idx: int): + assert isinstance(self.module, nn.ModuleList) + return self.module[idx] + + def __len__(self): + assert isinstance(self.module, nn.ModuleList) + return len(self.module) + + def __iter__(self): + assert isinstance(self.module, nn.ModuleList) + return iter(self.module)
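The new __getitem__/__len__/__iter__ methods let a ShardedModelV2 that wraps an nn.ModuleList of pipeline chunks be used wherever the interleaved schedule indexes or iterates engine.model (e.g. engine.model[0] and the `for model in engine.model` loop in pre_processing). A toy stand-in showing the same delegation pattern and the behaviour the schedule relies on:

import torch.nn as nn

class ListLikeWrapper(nn.Module):
    # Toy illustration only; ShardedModelV2 above adds the same three methods.
    def __init__(self, module: nn.ModuleList):
        super().__init__()
        self.module = module

    def __getitem__(self, idx):
        return self.module[idx]

    def __len__(self):
        return len(self.module)

    def __iter__(self):
        return iter(self.module)

chunks = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])
wrapped = ListLikeWrapper(chunks)
assert len(wrapped) == 2 and wrapped[0] is chunks[0]  # behaves like the bare ModuleList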