diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 756d32454..2b3a30bad 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -414,10 +414,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         from transformers import LlamaForSequenceClassification
 
         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index b60bdd03e..68ceb9ac1 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -108,6 +108,7 @@ def main():
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
     parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
+    parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
     parser.add_argument("--overlap_allgather", action="store_true")
     parser.add_argument(
         "--sp_mode",
@@ -256,7 +257,6 @@ def main():
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
             scheduler_nodes=scheduler_nodes,
-            make_vocab_size_divisible_by=1,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -272,7 +272,7 @@ def main():
             microbatch_size=args.mbs,
             initial_scale=2**8,
             precision="bf16",
-            overlap_p2p=True,
+            overlap_p2p=args.overlap_p2p,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
         )
@@ -338,7 +338,7 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
 
-    # torch.set_default_dtype(torch.float)
+    torch.set_default_dtype(torch.float)
     coordinator.print_on_master(
         f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
     )