Mirror of https://github.com/hpcaitech/ColossalAI.git
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI into dev/zero_bubble
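This merge drops a duplicated pp_style == "zbv" scheduler branch from HybridParallelPlugin and reworks the Llama pipeline policies for zero-bubble-v ("zbv") scheduling: the first pipeline stage, which also owns the model's final chunk under the V schedule, now holds the tail layers (norm, lm_head, score), and no cross-stage shared parameters are reported.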
@@ -1166,22 +1166,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                 num_microbatch=num_microbatches,
                 microbatch_size=microbatch_size,
             )
-        elif pp_style == "zbv":
-            self.scheduler = ZeroBubbleVPipeScheduler(
-                stage_manager=self.stage_manager,
-                schedule=scheduler_nodes,
-                num_model_chunks=num_model_chunks,
-                num_microbatch=num_microbatches,
-                microbatch_size=microbatch_size,
-            )
-        elif pp_style == "zbv":
-            self.scheduler = ZeroBubbleVPipeScheduler(
-                stage_manager=self.stage_manager,
-                schedule=scheduler_nodes,
-                num_model_chunks=num_model_chunks,
-                num_microbatch=num_microbatches,
-                microbatch_size=microbatch_size,
-            )
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
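The removed lines were a merge-conflict artifact: the same elif pp_style == "zbv": branch appeared twice in the chain, and Python dispatches to the first matching elif, so the second copy could never run. A minimal, self-contained illustration of why the duplicate was dead code (our own sketch, not ColossalAI code; the "1f1b" branch is a stand-in for the surrounding styles):

    # Sketch (not ColossalAI code): an elif chain takes the first matching
    # branch, so a duplicated condition can never fire.
    def pick_scheduler(pp_style: str) -> str:
        if pp_style == "1f1b":
            return "OneForwardOneBackwardSchedule"
        elif pp_style == "zbv":
            return "ZeroBubbleVPipeScheduler"  # always taken for "zbv"
        elif pp_style == "zbv":
            return "never reached"  # the duplicate branch is dead
        raise NotImplementedError(pp_style)

    assert pick_scheduler("zbv") == "ZeroBubbleVPipeScheduler"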
@@ -289,9 +289,9 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
                 held_layers.extend(module.layers[start_idx:end_idx])
-            if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-                held_layers.append(module.norm)
-            elif stage_manager.is_last_stage(ignore_chunk=True):
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
                 held_layers.append(module.norm)
 
         else:
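The rewritten condition encodes the zero-bubble V-schedule layout: each rank holds two model chunks arranged in a V, so the rank that is the first physical stage also owns the last chunk of the model and must hold tail layers such as the final norm (ignore_chunk=True asks about the physical stage rather than the current chunk). Without zbv, the tail stays on the last stage as before. A hedged sketch of the predicate the policies now share (StageInfo and holds_tail_layers are hypothetical names, not ColossalAI API):

    from dataclasses import dataclass

    # Hypothetical stand-in for the stage manager's flags.
    @dataclass
    class StageInfo:
        use_zbv: bool
        is_first: bool
        is_last: bool

    def holds_tail_layers(s: StageInfo) -> bool:
        # ZBV: stage 0 owns the model's tail; otherwise the last stage does.
        return (s.use_zbv and s.is_first) or (not s.use_zbv and s.is_last)

    assert holds_tail_layers(StageInfo(use_zbv=True, is_first=True, is_last=False))
    assert holds_tail_layers(StageInfo(use_zbv=False, is_first=False, is_last=True))
    assert not holds_tail_layers(StageInfo(use_zbv=True, is_first=False, is_last=True))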
@@ -383,13 +383,15 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.lm_head)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
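The new early return in get_shared_params reflects the same layout: tied weights (e.g. embed_tokens and lm_head) only need cross-stage gradient synchronization when the two ends of the model live on different ranks, and under the V schedule rank 0 holds both ends, so there is nothing to report. A rough illustration of the V-shaped chunk assignment (our own sketch assuming evenly divisible layers, not ColossalAI's actual partitioner):

    # Sketch: with a V schedule each rank holds two chunks, and rank 0
    # gets both the head and the tail of the model, so tied weights are
    # already colocated and need no cross-stage sharing.
    def v_chunks(num_layers: int, num_stages: int) -> dict:
        per = num_layers // (2 * num_stages)
        return {
            rank: (
                range(rank * per, (rank + 1) * per),  # chunk on the way down
                range((2 * num_stages - 1 - rank) * per, (2 * num_stages - rank) * per),  # and back up
            )
            for rank in range(num_stages)
        }

    down, up = v_chunks(num_layers=32, num_stages=4)[0]
    assert 0 in down and 31 in up  # rank 0 holds both ends of the model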
@@ -443,9 +445,9 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
-            held_layers.append(self.model.score)
-        elif stage_manager.is_last_stage(ignore_chunk=True):
+        if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+            not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+        ):
             held_layers.append(self.model.score)
         return held_layers
 
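The sequence-classification policy receives the identical treatment for its score head, so all three Llama policies place their tail layer with the same zbv-aware predicate sketched after the LlamaPolicy hunk above.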