Merge pull request #6314 from hpcaitech/grpo-reward-dev

[feat] upgrade reward functions
This commit is contained in:
YeAnbang 2025-05-20 10:06:00 +08:00 committed by GitHub
commit bcf2459db5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 183 additions and 51 deletions

1
.gitignore vendored
View File

@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
applications/ColossalChat/rollouts

View File

@ -120,14 +120,20 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
response_format_tags = (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if grpo_config.get("reward_fn_type") == "think_answer_tags"
else None
)
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[

View File

@ -55,6 +55,8 @@ def launch_distributed(
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@ -72,6 +74,10 @@ def launch_distributed(
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
wandb_group_name = str(uuid.uuid4())
rollout_log_file = os.path.join(
rollout_save_dir,
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)
procs = []
for i in range(num_producers):
@ -98,6 +104,8 @@ def launch_distributed(
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)

View File

@ -1,4 +1,5 @@
import copy
import json
import os
from typing import Any, Dict, Optional
@ -49,7 +50,8 @@ class BaseProducer:
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
wandb_log_rollout_interval: int = 20,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@ -70,9 +72,16 @@ class BaseProducer:
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
self.eval_mode = False
self.wandb_rollout_data = []
self.wandb_log_rollout_interval = wandb_log_rollout_interval
self.log_rollout_interval = log_rollout_interval
self.latest_rollout_log_step = -1
if producer_idx == 0:
if os.path.exists(rollout_log_file):
raise ValueError(
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
)
else:
os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
if self.producer_idx == 0:
self.wandb_run = wandb.init(
project=project_name,
@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
):
super().__init__(
producer_idx,
@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
wandb_rollout_data = self.wandb_rollout_data + [
[
str(self.consumer_global_step),
str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
]
]
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
self.wandb_rollout_data = wandb_rollout_data
self.latest_rollout_log_step = self.consumer_global_step
self.wandb_run.log(
{
"rollout/rollout_examples": wandb.Table(
columns=["train_step", "rollout_examples"], data=wandb_rollout_data
new_record = (
json.dumps(
{
"train_step": self.consumer_global_step,
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
}
)
}
)
+ "\n"
)
self.rollout_log_file.write(new_record)
self.rollout_log_file.flush()
self.latest_rollout_log_step = self.consumer_global_step
return rollouts
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)

View File

@ -1,7 +1,77 @@
import torch
from latex2sympy2_extended import NormalizationConfig
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
CANNOT_PARSE_GT_ANSWER = -1
CANNOT_PARSE_PREDICTION = -2
SUCCESS = 1
MATCHING_FAIL = 0
def verify_math_representation(completion, gt_answer):
"""
Verify if the completion is a valid math representation of the gt_answer.
"""
if not completion.startswith("\\boxed{"):
completion = "\\boxed{" + completion + "}"
if not gt_answer.startswith("\\boxed{"):
gt_answer = "\\boxed{" + gt_answer + "}"
target = (
ExprExtractionConfig(),
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
),
)
if not isinstance(gt_answer, str) or len(gt_answer) == 0:
raise ValueError("gt_answer should be a string, please verify your training data.")
if not isinstance(completion, str) or len(completion) == 0:
return MATCHING_FAIL
try:
parsed_gt_answer = parse(gt_answer, extraction_config=target)
if len(parsed_gt_answer) == 0:
return CANNOT_PARSE_GT_ANSWER
parsed_completion = parse(completion, extraction_config=target)
if len(parsed_completion) == 0:
return CANNOT_PARSE_PREDICTION
if verify(parsed_gt_answer, parsed_completion):
return SUCCESS
else:
return MATCHING_FAIL
except Exception:
return MATCHING_FAIL
def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
exact_match_result = (
SUCCESS
if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
== gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
else MATCHING_FAIL
)
if math_verify_result == SUCCESS:
ans_acc += 1
reward += acc_score
elif exact_match_result == SUCCESS:
# sometimes for answers that's not a (valid) math expression, math_verify will fail
ans_acc += 1
if math_verify_result == CANNOT_PARSE_PREDICTION:
reward += (
acc_score / 2
) # not a valid latex math representation, but the answer is correct, receive half of the score
else:
reward += acc_score
return reward, ans_acc
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
if soft_over_length_punishment:
max_length = kwargs.get("max_length", 1024 * 4)
cache_length = kwargs.get("cache_length", 512)
res_length = e.item() - s.item() + 1
if max_length - cache_length < res_length < max_length:
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
res_length = e.item() - s.item() + 1
if not eval_mode:
max_new_tokens = kwargs["max_new_tokens"]
else:
max_new_tokens = -1 # for eval mode, we don't need to check the length
if not eval_mode and soft_over_length_punishment:
cache_length = kwargs["cache_length"]
if max_new_tokens - cache_length < res_length < max_new_tokens:
length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
if gt_answer is None:
return reward
raise ValueError("no gt_answer is provided, please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@ -35,15 +108,15 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_acc += 1
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
if (
format_valid
and final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
):
ans_acc += 1
reward += acc_score
if final_answer is not None:
if eval_mode or format_valid:
reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
if not eval_mode:
reward = reward + length_reward
reward = reward + length_reward
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
@ -56,6 +129,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
"response_length": res_length,
"reward": reward.item(),
}
@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
if soft_over_length_punishment:
max_length = kwargs.get("max_length", 1024 * 4)
cache_length = kwargs.get("cache_length", 512)
res_length = e.item() - s.item() + 1
if max_length - cache_length < res_length < max_length:
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
res_length = e.item() - s.item() + 1
if not eval_mode:
max_new_tokens = kwargs["max_new_tokens"]
else:
max_new_tokens = -1 # for eval mode, we don't need to check the length
if not eval_mode and soft_over_length_punishment:
cache_length = kwargs["cache_length"]
if max_new_tokens - cache_length < res_length < max_new_tokens:
length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
if gt_answer is None:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
raise ValueError("no gt_answer is provided, please check your training dataset.")
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
if "tags" in kwargs and kwargs["tags"]:
tags = kwargs["tags"]
format_valid = format_valid and all(
[decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
)
# Check format accuracy
if format_valid:
format_acc += 1
reward += format_score
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
ans_acc += 1
reward += acc_score
if final_answer is not None:
if eval_mode or format_valid:
reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
if not eval_mode:
reward = reward + length_reward
# Check if the sequence is over length
if not eval_mode and res_length >= max_new_tokens:
reward *= 0.0
reward = reward + length_reward
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
else:
@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
"response_length": res_length,
"reward": reward.item(),
}

View File

@ -118,6 +118,9 @@ if __name__ == "__main__":
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@ -198,6 +201,8 @@ if __name__ == "__main__":
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
}
elif args.algo == "DAPO":
# DAPO variant settings
@ -213,6 +218,7 @@ if __name__ == "__main__":
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
@ -266,4 +272,6 @@ if __name__ == "__main__":
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_save_dir=args.rollout_save_dir,
)