diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 8b1f69040..d122c7286 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -144,8 +144,8 @@ if __name__ == "__main__": max_length=args.max_new_tokens + args.max_prompt_tokens, do_sample=True, max_new_tokens=None, - early_stopping=False, - stop_strings=[""], + early_stopping=False if args.reward_type == "think_answer_tags" else True, + stop_strings=[""] if args.reward_type == "think_answer_tags" else None, ) ) elif args.backend == "vllm": @@ -161,9 +161,9 @@ if __name__ == "__main__": generate_config.update( dict( max_tokens=args.max_new_tokens, # max new tokens - ignore_eos=True, + ignore_eos=True if args.reward_type == "think_answer_tags" else False, include_stop_str_in_output=True, - stop=[""], + stop=[""] if args.reward_type == "think_answer_tags" else None, ) ) else: