mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
[fix] fix weekly runing example (#4787)
* [fix] fix weekly runing example * [fix] fix weekly runing example
This commit is contained in:
parent
d512a4d38d
commit
26cd6d850c
@ -145,7 +145,7 @@ def main():
|
|||||||
if args.plugin.startswith("torch_ddp"):
|
if args.plugin.startswith("torch_ddp"):
|
||||||
plugin = TorchDDPPlugin()
|
plugin = TorchDDPPlugin()
|
||||||
elif args.plugin == "gemini":
|
elif args.plugin == "gemini":
|
||||||
plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
|
plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
|
||||||
elif args.plugin == "low_level_zero":
|
elif args.plugin == "low_level_zero":
|
||||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ def main():
|
|||||||
if args.plugin.startswith("torch_ddp"):
|
if args.plugin.startswith("torch_ddp"):
|
||||||
plugin = TorchDDPPlugin()
|
plugin = TorchDDPPlugin()
|
||||||
elif args.plugin == "gemini":
|
elif args.plugin == "gemini":
|
||||||
plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
|
plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
|
||||||
elif args.plugin == "low_level_zero":
|
elif args.plugin == "low_level_zero":
|
||||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ from colossalai.utils import get_current_device
|
|||||||
# ==============================
|
# ==============================
|
||||||
# Prepare Hyperparameters
|
# Prepare Hyperparameters
|
||||||
# ==============================
|
# ==============================
|
||||||
NUM_EPOCHS = 3
|
NUM_EPOCHS = 1
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
LEARNING_RATE = 2.4e-5
|
LEARNING_RATE = 2.4e-5
|
||||||
WEIGHT_DECAY = 0.01
|
WEIGHT_DECAY = 0.01
|
||||||
@ -141,7 +141,7 @@ def main():
|
|||||||
if args.plugin.startswith("torch_ddp"):
|
if args.plugin.startswith("torch_ddp"):
|
||||||
plugin = TorchDDPPlugin()
|
plugin = TorchDDPPlugin()
|
||||||
elif args.plugin == "gemini":
|
elif args.plugin == "gemini":
|
||||||
plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
|
plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
|
||||||
elif args.plugin == "low_level_zero":
|
elif args.plugin == "low_level_zero":
|
||||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||||
|
|
||||||
|
@ -4,5 +4,5 @@ set -xe
|
|||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
|
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
|
||||||
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin
|
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.80 --plugin $plugin
|
||||||
done
|
done
|
||||||
|
Loading…
Reference in New Issue
Block a user