diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py index 60f049bd5..76929c9d0 100644 --- a/applications/Chat/examples/ray/mmmt_prompt.py +++ b/applications/Chat/examples/ray/mmmt_prompt.py @@ -87,8 +87,8 @@ def main(args): kl_coef=0.1, debug=args.debug, update_lora_weights=not (args.lora_rank == 0), - # sync_models_from_trainers=True, - # generation kwargs: + # sync_models_from_trainers=True, + # generation kwargs: max_length=512, do_sample=True, temperature=1.0, @@ -161,12 +161,10 @@ if __name__ == '__main__': parser.add_argument('--prompt_path', type=str, default=None) parser.add_argument('--num_makers', type=int, default=1) parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') + parser.add_argument( + '--trainer_strategy', + choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'], + default='ddp') parser.add_argument('--maker_strategy', choices=['naive'], default='naive') parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])