mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-07 07:58:27 +00:00
[chat] add zero2 cpu strategy for sft training (#3520)
This commit is contained in:
parent
990d4c3e4e
commit
89fd10a1c9
@ -35,6 +35,8 @@ def train(args):
|
|||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
||||||
elif args.strategy == 'colossalai_zero2':
|
elif args.strategy == 'colossalai_zero2':
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||||
|
elif args.strategy == 'colossalai_zero2_cpu':
|
||||||
|
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||||
|
|
||||||
@ -168,7 +170,7 @@ def train(args):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--strategy',
|
parser.add_argument('--strategy',
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
|
||||||
default='naive')
|
default='naive')
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
parser.add_argument('--pretrain', type=str, default=None)
|
||||||
|
Loading…
Reference in New Issue
Block a user