diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py
index 6ae2d8b04..4407a51c3 100644
--- a/examples/tutorial/new_api/cifar_resnet/train.py
+++ b/examples/tutorial/new_api/cifar_resnet/train.py
@@ -145,7 +145,7 @@ def main():
     if args.plugin.startswith("torch_ddp"):
         plugin = TorchDDPPlugin()
     elif args.plugin == "gemini":
-        plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
+        plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == "low_level_zero":
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py
index 226a4b320..700e4d2e0 100644
--- a/examples/tutorial/new_api/cifar_vit/train.py
+++ b/examples/tutorial/new_api/cifar_vit/train.py
@@ -165,7 +165,7 @@ def main():
     if args.plugin.startswith("torch_ddp"):
         plugin = TorchDDPPlugin()
     elif args.plugin == "gemini":
-        plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
+        plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py
index 7d69dbc06..990822c9f 100644
--- a/examples/tutorial/new_api/glue_bert/finetune.py
+++ b/examples/tutorial/new_api/glue_bert/finetune.py
@@ -21,7 +21,7 @@ from colossalai.utils import get_current_device
 # ==============================
 # Prepare Hyperparameters
 # ==============================
-NUM_EPOCHS = 3
+NUM_EPOCHS = 1
 BATCH_SIZE = 32
 LEARNING_RATE = 2.4e-5
 WEIGHT_DECAY = 0.01
@@ -141,7 +141,7 @@ def main():
     if args.plugin.startswith("torch_ddp"):
         plugin = TorchDDPPlugin()
     elif args.plugin == "gemini":
-        plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5)
+        plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == "low_level_zero":
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
diff --git a/examples/tutorial/new_api/glue_bert/test_ci.sh b/examples/tutorial/new_api/glue_bert/test_ci.sh
index c2c097f8d..56dd431f1 100755
--- a/examples/tutorial/new_api/glue_bert/test_ci.sh
+++ b/examples/tutorial/new_api/glue_bert/test_ci.sh
@@ -4,5 +4,5 @@
 set -xe
 pip install -r requirements.txt
 for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
-  torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin
+  torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.80 --plugin $plugin
done