diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 8ad9b7956..d33e3485c 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        enable_sequence_overlap: bool = False,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,
@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
-            enable_sequence_parallelism=enable_sequence_parallelism)
+            enable_sequence_parallelism=enable_sequence_parallelism,
+            enable_sequence_overlap=enable_sequence_overlap)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index f45ccc64b..45b305733 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         overlap = ctx.overlap

         if not overlap:
-            # TODO: overlap SP input with gradient computation
             input_parallel = _gather(input_, dim, process_group)

             total_input = input_parallel
@@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
                 grad_output = grad_output.view(-1, grad_output.shape[-1])
                 total_input = total_input.view(-1, total_input.shape[-1])

-            # TODO: overlap SP input with gradient computation
             if ctx.async_grad_reduce_scatter:
                 # Asynchronous reduce-scatter
                 input_list = [
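
For context, a minimal sketch of how the new `enable_sequence_overlap` flag would be switched on from user code. The launcher call, parallel sizes, and surrounding boilerplate are illustrative assumptions, not part of this diff; only the two `enable_sequence_*` keyword arguments come from the change above.

```python
# Illustrative sketch only: shows where the new flag plugs in.
# Everything except the two `enable_sequence_*` kwargs is assumed boilerplate.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # assumed launch path for a torchrun job

# `enable_sequence_overlap` is only meaningful together with sequence
# parallelism: the plugin forwards both flags into ShardConfig (first hunk),
# and the overlap path is exercised in the sharded linear's backward pass.
plugin = HybridParallelPlugin(
    tp_size=2,                          # sequence parallelism needs tp_size > 1
    pp_size=1,
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,       # new flag introduced by this diff
)
booster = Booster(plugin=plugin)
# model, optimizer, dataloader, etc. would then be wrapped via booster.boost(...)
```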