[example] update gpt gemini example ci test (#2477)

This commit is contained in:
ver217
2023-01-13 22:37:31 +08:00
committed by GitHub
parent fef5c949c3
commit f525d1f528
2 changed files with 14 additions and 16 deletions

View File

@@ -65,6 +65,7 @@ def parse_args():
default="gpt2_medium",
help="model model scale",
)
parser.add_argument("--steps", type=int, default=10, help="num of training steps")
args = parser.parse_args()
return args
@@ -236,7 +237,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
NUM_STEPS = args.steps
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
@@ -290,14 +291,12 @@ def main():
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
elif args.distplan.startswith("zero"):
pg = ProcessGroup()
model = model.half()
partition_flag = (args.distplan == "zero2")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer = LowLevelZeroOptimizer(
optimizer,
pg=pg,
reduce_bucket_size=12 * 1024 * 1024,
overlap_communication=True,
partition_grad=partition_flag,