add profiling

This commit is contained in:
Tong Li
2025-06-26 17:49:53 +08:00
parent 71ef6b32c6
commit 58cb4fb4f7
2 changed files with 24 additions and 9 deletions

View File

@@ -246,14 +246,25 @@ if __name__ == "__main__":
tensor_parallel_size=args.producer_tensor_parallel_size,
)
)
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
if args.enable_profiling:
# If profiling is enabled, we force model to generate to max_new_tokens
inference_model_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True,
include_stop_str_in_output=True,
stop=None,
)
)
else:
generate_config.update(
dict(
max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True if args.reward_type == "think_answer_tags" else False,
include_stop_str_in_output=True,
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
)
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
else:
raise ValueError(f"Unsupported backend: {args.backend}")