Mirror of https://github.com/hpcaitech/ColossalAI.git
[fix] fix comment in llama & benchmark
commit 982e4ee1f8 (parent fa3ccda8ee)
@@ -414,10 +414,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         from transformers import LlamaForSequenceClassification

         policy = super().module_policy()
-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
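The hunk above collapses the four-line if/else into a single short-circuiting expression; for a stage manager that is either None or an ordinary (truthy) object, both forms produce the same boolean. A minimal standalone sketch of that equivalence, using a hypothetical stand-in class rather than the real stage manager:

class _StageManager:
    # Hypothetical stand-in; only the use_zbv attribute matters here.
    def __init__(self, use_zbv: bool):
        self.use_zbv = use_zbv

def use_zbv_old(stage_manager):
    # Pre-change logic: explicit branch on the manager.
    if stage_manager:
        return stage_manager.use_zbv
    return False

def use_zbv_new(stage_manager):
    # Post-change logic: single short-circuiting expression.
    return stage_manager is not None and stage_manager.use_zbv

for mgr in (None, _StageManager(False), _StageManager(True)):
    assert bool(use_zbv_old(mgr)) == bool(use_zbv_new(mgr))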
@@ -108,6 +108,7 @@ def main():
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
     parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
+    parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
     parser.add_argument("--overlap_allgather", action="store_true")
     parser.add_argument(
         "--sp_mode",
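The new --overlap_p2p flag combines action="store_true" with default=True, so the parsed value is True whether or not the flag appears on the command line. A small standalone check of that parsing behavior (this parser is illustrative, not the benchmark's actual parser):

import argparse

# Standalone reproduction of the new flag's parsing behavior.
parser = argparse.ArgumentParser()
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")

print(parser.parse_args([]).overlap_p2p)                 # True: the default applies
print(parser.parse_args(["--overlap_p2p"]).overlap_p2p)  # True: the flag sets it again

The parsed args.overlap_p2p is what replaces the hard-coded overlap_p2p=True in the plugin setup further down in this commit.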
@@ -256,7 +257,6 @@ def main():
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
             scheduler_nodes=scheduler_nodes,
-            make_vocab_size_divisible_by=1,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
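Dropping make_vocab_size_divisible_by=1 stops forcing the divisor to 1 (which effectively disabled vocabulary padding) and lets the plugin fall back to its own default, which is not shown in this diff. As a rough illustration of the padding arithmetic such a kwarg controls, the helper below and the divisor value 64 are assumptions for demonstration only:

def pad_vocab_size(vocab_size: int, divisor: int) -> int:
    # Round the vocabulary size up to the next multiple of `divisor`.
    return ((vocab_size + divisor - 1) // divisor) * divisor

print(pad_vocab_size(32000, 1))   # 32000: a divisor of 1 pads nothing
print(pad_vocab_size(32001, 64))  # 32064: padded up to the next multiple of 64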
@@ -272,7 +272,7 @@ def main():
             microbatch_size=args.mbs,
             initial_scale=2**8,
             precision="bf16",
-            overlap_p2p=True,
+            overlap_p2p=args.overlap_p2p,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
         )
@@ -338,7 +338,7 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

-    # torch.set_default_dtype(torch.float)
+    torch.set_default_dtype(torch.float)
     coordinator.print_on_master(
         f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
     )
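Re-enabling the previously commented-out torch.set_default_dtype(torch.float) restores the float32 default after booster.boost(...), so the bfloat16 default only brackets the boost call. A small sketch of what that bracket does to newly created parameters; the linear layers are placeholders, not the benchmark's model:

import torch

torch.set_default_dtype(torch.bfloat16)
boosted = torch.nn.Linear(4, 4)       # created while the default dtype is bfloat16
torch.set_default_dtype(torch.float)  # restore float32 for everything created later
later = torch.nn.Linear(4, 4)         # created with the restored float32 default

print(boosted.weight.dtype, later.weight.dtype)  # torch.bfloat16 torch.float32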