mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[chat] fix gemini strategy (#4698)
* [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * g# This is a combination of 2 commits. [chat] fix gemini strategy fox * [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * fix * fix * fix * fix * fix * Update train_prompts.py
This commit is contained in:
@@ -76,9 +76,9 @@ def main(args):
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_gemini_cpu":
|
||||
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2_cpu":
|
||||
|
Reference in New Issue
Block a user