fix reward tagging bug

author YeAnbang 2025-05-03 14:34:04 +08:00
parent da867a4d8f
commit 2999bd4cc8
4 changed files with 29 additions and 32 deletions


@@ -128,7 +128,6 @@ class BaseConsumer:
                     k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                 }
             eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-            if dist.get_rank() == 0:
             if hasattr(self, "wandb_run"):
                 self.wandb_run.log(eval_statistics, step=eval_global_step)
             print(f"Eval statistics: {eval_statistics}")
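The dropped dist.get_rank() == 0 guard leaves hasattr(self, "wandb_run") as the only gate; since setup() (changed below) only creates the wandb run on a single designated rank, the attribute check presumably suffices. A minimal sketch of that pattern, with hypothetical names:

import wandb  # assumes wandb is installed, as in the surrounding code

class _Consumer:
    """Hypothetical consumer illustrating attribute-gated logging."""

    def setup(self, is_logging_rank: bool) -> None:
        # Only the designated rank ever creates the attribute.
        if is_logging_rank:
            self.wandb_run = wandb.init(project="demo", mode="offline")

    def log_eval(self, stats: dict, step: int) -> None:
        # No explicit rank check needed: non-logging ranks simply lack the attribute.
        if hasattr(self, "wandb_run"):
            self.wandb_run.log(stats, step=step)
        print(f"Eval statistics: {stats}")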


@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):
     def setup(self):
         super().setup()
-        if self.use_wandb and (
+        if (
+            self.use_wandb
+            and self.dp_rank == 0
+            and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
             or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+            )
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+                if self.dp_rank == 0 and (
+                    (not self.plugin.pp_size > 1 and self.rank == 0)
+                    or (
+                        self.plugin.pp_size > 1
+                        and self.booster.plugin.stage_manager.is_last_stage()
+                        and self.tp_rank == 0
+                    )
                 ):
                     to_log_msg = [
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",


@@ -148,7 +148,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
-    if "tags" in kwargs:
+    if "tags" in kwargs and kwargs["tags"]:
         tags = kwargs["tags"]
         format_valid = format_valid and all(
             [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
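The added kwargs["tags"] truthiness check means an empty or None tags mapping now skips the per-tag count check instead of being applied. A short sketch of what the check computes when tags are supplied (tag structure borrowed from the DEFAULT_RESPONSE_FORMAT_TAGS block deleted further down; the two-tag mapping is just for illustration):

tags = {
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
decoded_final_answer = r"<answer>\boxed{4}</answer>"
# Each tag text must occur exactly num_occur times in the decoded response.
format_valid = all(
    decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags
)
print(format_valid)  # True; with tags = {} or tags = None the new guard skips this check entirely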


@@ -6,13 +6,6 @@ import ray
 import torch
 from coati.distributed.launch import launch_distributed
 
-DEFAULT_RESPONSE_FORMAT_TAGS = {
-    "think_start": {"text": "<think>", "num_occur": 1},
-    "think_end": {"text": "</think>", "num_occur": 1},
-    "answer_start": {"text": "<answer>", "num_occur": 1},
-    "answer_end": {"text": "</answer>", "num_occur": 1},
-}
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -248,18 +241,18 @@ if __name__ == "__main__":
         num_generations=args.num_generations,
         train_model_config=train_model_config,
         grpo_config=grpo_config,
-        plugin_config={
-            "zero_stage": 2,
-        }, # for zero
         # plugin_config={
-        #     "tp_size": 4,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ), # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 1,
-        #     "max_norm": 1.0,
-        # }, # for pp, tp
+        #     "zero_stage": 2,
+        # }, # for zero
+        plugin_config={
+            "tp_size": 4,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ), # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 1,
+            "max_norm": 1.0,
+        }, # for pp, tp
         inference_backend=args.backend,
         master_port=args.torch_ddp_master_port,
         core_algo=args.algo,
@@ -272,9 +265,5 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
-        response_format_tags=(
-            json.loads(args.response_format_tags)
-            if args.response_format_tags is not None
-            else DEFAULT_RESPONSE_FORMAT_TAGS
-        ),
+        response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
     )
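With the built-in default gone, response-format checking is now opt-in: the launch script falls back to an empty dict, and callers who still want the old behaviour must pass the tag spec themselves. A sketch of producing that JSON value (the command-line option name is assumed to correspond to args.response_format_tags used above; the dict mirrors the deleted DEFAULT_RESPONSE_FORMAT_TAGS):

import json

# Mirrors the DEFAULT_RESPONSE_FORMAT_TAGS block removed above.
tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
# json.dumps(tags) yields the string to pass via the script's response-format-tags option.
print(json.dumps(tags))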