diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 4f8ec162f..df9b91da2 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -268,6 +268,7 @@ class MixtralPipelineForwards:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
+        print(f"model_chunk_id {stage_manager.model_chunk_id} stage_manager {stage_manager.stage}")
         if stage_manager.is_first_stage():
             # retrieve input_ids and inputs_embeds
             if input_ids is not None and inputs_embeds is not None:
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index af5b15ed5..a8cd49dc1 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -343,8 +343,18 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
         """Get pipeline layers for current stage."""
         stage_manager = self.pipeline_stage_manager
         held_layers = super().get_held_layers()
-        if stage_manager.is_last_stage():
-            held_layers.append(self.model.lm_head)
+        if stage_manager.is_interleave:
+            if stage_manager.use_zbv:
+                if stage_manager.is_first_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+            else:
+                if stage_manager.is_last_stage(ignore_chunk=True):
+                    held_layers.append(self.model.lm_head)
+        else:
+            if stage_manager.is_last_stage():
+                held_layers.append(self.model.lm_head)
+        # if stage_manager.is_last_stage():
+        #     held_layers.append(self.model.lm_head)
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py
index 7c8a5fe65..2685afced 100644
--- a/examples/language/mixtral/benchmark.py
+++ b/examples/language/mixtral/benchmark.py
@@ -167,6 +167,7 @@ def main():
         enable_fused_normalization=torch.cuda.is_available(),
         enable_flash_attention=args.xformers,
         microbatch_size=args.mbs,
+        num_microbatches=args.batch_size // args.mbs,
         precision="bf16",
         enable_metadata_cache=not args.no_cache,
         overlap_allgather=args.overlap_allgather,
@@ -208,8 +209,10 @@ def main():
 
     with init_ctx:
         model = MixtralForCausalLM(config=config).to(torch.bfloat16)
+    # if args.grad_checkpoint:
+    #     model.gradient_checkpointing_enable()
     if args.grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@@ -224,6 +227,7 @@ def main():
     )
 
     optimizer = HybridAdam(model.parameters())
+    # optimizer = torch.optim.SGD(model.parameters(), lr=1)
     torch.set_default_dtype(torch.bfloat16)
 
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
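
For reference, the sketch below isolates the stage-selection logic this patch adds to MixtralForCausalLMPolicy.get_held_layers: under the zero-bubble V (ZBV) interleaved schedule the final model chunk runs on the first pipeline rank, so that rank holds lm_head; otherwise the last rank holds it. This is a minimal standalone sketch, not ColossalAI code: the StageInfo dataclass and holds_lm_head helper are hypothetical stand-ins for the PipelineStageManager, used only to make the branching runnable in isolation.

from dataclasses import dataclass


@dataclass
class StageInfo:
    """Hypothetical stand-in for the PipelineStageManager flags used in the patch."""

    stage: int          # this rank's pipeline stage index
    num_stages: int     # total number of pipeline stages
    is_interleave: bool = False
    use_zbv: bool = False


def holds_lm_head(info: StageInfo) -> bool:
    # Mirrors the branching added to MixtralForCausalLMPolicy.get_held_layers.
    if info.is_interleave:
        if info.use_zbv:
            # ZBV folds the last model chunk back onto the first pipeline rank,
            # so lm_head lives on stage 0 (chunk index ignored).
            return info.stage == 0
        # Plain interleaved schedule: lm_head stays on the last rank.
        return info.stage == info.num_stages - 1
    # Non-interleaved (e.g. 1F1B): last stage holds lm_head.
    return info.stage == info.num_stages - 1


if __name__ == "__main__":
    # With 4 pipeline ranks, ZBV moves lm_head from rank 3 to rank 0.
    for zbv in (False, True):
        owners = [s for s in range(4) if holds_lm_head(StageInfo(s, 4, is_interleave=True, use_zbv=zbv))]
        print(f"use_zbv={zbv}: lm_head held on rank(s) {owners}")

The benchmark change similarly assumes args.batch_size is divisible by args.mbs; for example, --batch_size 32 --mbs 1 yields num_microbatches = 32 per step.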