diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index d122c7286..39584750c 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -219,19 +219,18 @@ if __name__ == "__main__": num_generations=args.num_generations, train_model_config=train_model_config, grpo_config=grpo_config, - # plugin_config={ - # "zero_stage": 2, - # }, # for zero - # currently not support tp/pp plugin_config={ - "tp_size": 2, - "pp_size": 2, - "microbatch_size": max( - 1, args.train_microbatch_size // 2 - ), # microbatch size should be set to train_microbatch_size // pp_size - "zero_stage": 0, - "max_norm": 1.0, - }, # for pp + "zero_stage": 2, + }, # for zero + # plugin_config={ + # "tp_size": 2, + # "pp_size": 2, + # "microbatch_size": max( + # 1, args.train_microbatch_size // 2 + # ), # microbatch size should be set to train_microbatch_size // pp_size + # "zero_stage": 0, + # "max_norm": 1.0, + # }, # for pp, tp inference_backend=args.backend, master_addr="localhost", master_port=args.master_port,