Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 21:49:08 +00:00)
fix reward tagging bug
commit 2999bd4cc8
parent da867a4d8f
@@ -128,7 +128,6 @@ class BaseConsumer:
                         k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                     }
                 eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-                if dist.get_rank() == 0:
                 if hasattr(self, "wandb_run"):
                     self.wandb_run.log(eval_statistics, step=eval_global_step)
                 print(f"Eval statistics: {eval_statistics}")
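For context on the hunk above: dropping the "if dist.get_rank() == 0:" guard means eval logging is gated only by whether the process owns a wandb run. A minimal sketch of that pattern, not the repository code (the method name log_eval and its arguments are illustrative; eval_statistics is assumed to map metric names to (value_sum, sample_count) pairs as in the hunk):

def log_eval(self, eval_statistics: dict, eval_global_step: int) -> None:
    # Convert accumulated (value_sum, sample_count) pairs into averaged metrics.
    eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
    if hasattr(self, "wandb_run"):
        # Only the process that actually initialized wandb logs to it; that process
        # is no longer required to be global rank 0.
        self.wandb_run.log(eval_statistics, step=eval_global_step)
    print(f"Eval statistics: {eval_statistics}")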
@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):

     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (
+            self.use_wandb
+            and self.dp_rank == 0
+            and (
+                (not self.plugin.pp_size > 1 and self.rank == 0)
+                or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+            )
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
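The refactored condition above reads more easily as a single predicate. A minimal sketch with a hypothetical helper name (should_init_wandb); the attributes it reads (use_wandb, dp_rank, tp_rank, rank, plugin.pp_size, booster.plugin.stage_manager) are the ones referenced in the hunk:

def should_init_wandb(self) -> bool:
    # wandb is initialized on exactly one process: the first data-parallel replica,
    # and within it either global rank 0 (no pipeline parallelism) or the tp-rank-0
    # process on the last pipeline stage (with pipeline parallelism).
    if not self.use_wandb or self.dp_rank != 0:
        return False
    if self.plugin.pp_size > 1:
        return self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
    return self.rank == 0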
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+                if self.dp_rank == 0 and (
+                    (not self.plugin.pp_size > 1 and self.rank == 0)
+                    or (
+                        self.plugin.pp_size > 1
+                        and self.booster.plugin.stage_manager.is_last_stage()
+                        and self.tp_rank == 0
+                    )
                 ):
                     to_log_msg = [
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
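As a sanity check on the same gate now used for metric logging, here is a small standalone spot check (the function and the rank values below are illustrative, not repository code):

def passes_log_gate(dp_rank, rank, tp_rank, pp_size, is_last_stage):
    # Mirrors the condition in the hunk above.
    return dp_rank == 0 and (
        (not pp_size > 1 and rank == 0)
        or (pp_size > 1 and is_last_stage and tp_rank == 0)
    )

assert passes_log_gate(dp_rank=0, rank=0, tp_rank=0, pp_size=1, is_last_stage=True)       # single-stage, rank 0
assert not passes_log_gate(dp_rank=1, rank=4, tp_rank=0, pp_size=1, is_last_stage=True)   # other DP replica: skipped
assert passes_log_gate(dp_rank=0, rank=3, tp_rank=0, pp_size=2, is_last_stage=True)       # last PP stage logs
assert not passes_log_gate(dp_rank=0, rank=0, tp_rank=0, pp_size=2, is_last_stage=False)  # earlier PP stage: skipped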
@@ -148,7 +148,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
-    if "tags" in kwargs:
+    if "tags" in kwargs and kwargs["tags"]:
         tags = kwargs["tags"]
         format_valid = format_valid and all(
             [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
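The guard change above is the reward-tagging fix itself: the tag-count check now runs only when a non-empty tags dict is actually supplied. A minimal sketch of the check in isolation, assuming the same tag-dict schema ({name: {"text": ..., "num_occur": ...}}); the helper name is hypothetical:

def tags_format_valid(decoded_final_answer: str, tags) -> bool:
    if not tags:
        # None or {}: nothing to enforce. With the old guard, a present-but-falsy
        # tags value still entered the tag-check branch.
        return True
    return all(
        decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags
    )

tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
}
assert tags_format_valid("<think>reasoning</think> \\boxed{42}", tags)
assert not tags_format_valid("\\boxed{42}", tags)  # required tags are missing
assert tags_format_valid("\\boxed{42}", {})        # empty dict: tag check is skipped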
@@ -6,13 +6,6 @@ import ray
 import torch
 from coati.distributed.launch import launch_distributed

-DEFAULT_RESPONSE_FORMAT_TAGS = {
-    "think_start": {"text": "<think>", "num_occur": 1},
-    "think_end": {"text": "</think>", "num_occur": 1},
-    "answer_start": {"text": "<answer>", "num_occur": 1},
-    "answer_end": {"text": "</answer>", "num_occur": 1},
-}
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
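With the hard-coded DEFAULT_RESPONSE_FORMAT_TAGS gone, the same tags must now be passed in explicitly through the script's response_format_tags argument as a JSON string. A sketch of reproducing the old default (the dict is copied from the removed block; how the JSON string reaches the launcher depends on your setup):

import json

default_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}
print(json.dumps(default_tags))  # pass this string as the response_format_tags argument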
@@ -248,18 +241,18 @@ if __name__ == "__main__":
         num_generations=args.num_generations,
         train_model_config=train_model_config,
         grpo_config=grpo_config,
-        plugin_config={
-            "zero_stage": 2,
-        }, # for zero
         # plugin_config={
-        #     "tp_size": 4,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ), # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 1,
-        #     "max_norm": 1.0,
-        # }, # for pp, tp
+        #     "zero_stage": 2,
+        # }, # for zero
+        plugin_config={
+            "tp_size": 4,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ), # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 1,
+            "max_norm": 1.0,
+        }, # for pp, tp
         inference_backend=args.backend,
         master_port=args.torch_ddp_master_port,
         core_algo=args.algo,
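A worked example of the microbatch-size comment in the new plugin_config (the values below are illustrative, assuming args.train_microbatch_size == 8 and the pp_size of 2 shown in the hunk; the dict is presumably forwarded to the hybrid-parallel booster plugin):

pp_size = 2
train_microbatch_size = 8  # stand-in for args.train_microbatch_size
microbatch_size = max(1, train_microbatch_size // pp_size)  # train_microbatch_size // pp_size == 4
assert microbatch_size == 4

plugin_config = {
    "tp_size": 4,
    "pp_size": pp_size,
    "microbatch_size": microbatch_size,
    "zero_stage": 1,
    "max_norm": 1.0,
}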
@@ -272,9 +265,5 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
-        response_format_tags=(
-            json.loads(args.response_format_tags)
-            if args.response_format_tags is not None
-            else DEFAULT_RESPONSE_FORMAT_TAGS
-        ),
+        response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
     )
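A minimal sketch of the new default handling mirrored from the expression above (the helper name is hypothetical): when no tags are supplied, the reward function now receives an empty dict rather than the removed default tags, so no tag-count constraint is applied.

import json

def parse_response_format_tags(arg):
    # None now falls back to {} (no tag constraint) instead of DEFAULT_RESPONSE_FORMAT_TAGS.
    return json.loads(arg) if arg is not None else {}

assert parse_response_format_tags(None) == {}
assert parse_response_format_tags('{"think_start": {"text": "<think>", "num_occur": 1}}') == {
    "think_start": {"text": "<think>", "num_occur": 1}
}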