From 89fd10a1c9189c7961c7750309a94e1ae6b623d4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 10 Apr 2023 19:00:13 +0800 Subject: [PATCH] [chat] add zero2 cpu strategy for sft training (#3520) --- applications/Chat/examples/train_sft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index c0ac7b177..22f70e485 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -35,6 +35,8 @@ def train(args): strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') elif args.strategy == 'colossalai_zero2': strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2_cpu': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -168,7 +170,7 @@ def train(args): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--strategy', - choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], default='naive') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument('--pretrain', type=str, default=None)