fix metric calculation

YeAnbang 2025-05-20 18:14:05 +08:00
parent 116621d004
commit 37663386bc
4 changed files with 147 additions and 48 deletions

View File

@@ -120,24 +120,85 @@ class BaseConsumer:
raw_batch = unbind_batch(
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
)
processed_batch = [
self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
]
filtered_batch = [t for t in processed_batch if t is not None]
recv_effective_count = 0
# Calculate group rewards and apply prompt-level filtering. Since only the filtered groups are used for training
# (an incomplete subset of the batch), compute the logging metrics on every raw group before filtering.
for group in raw_batch:
group_with_reward = self.calculate_group_reward(group)
group_reward_mean = group_with_reward["reward"].mean().cpu().item()
group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
group_response_len = (
(
group_with_reward["response_idx"][:, 1]
- group_with_reward["response_idx"][:, 0]
+ 1
)
.type(torch.float32)
.mean()
.cpu()
.item()
)
filtered_group = self.prompt_level_filtering(group_with_reward)
recv_effective_count += 1 if filtered_group is not None else 0
self.buffer.append(
[
filtered_group,
group_reward_mean,
group_format_acc_mean,
group_ans_acc_mean,
group_response_len,
]
)
if self.filter_range is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}"
)
# map each effective (unfiltered) group to its raw group index in the buffer
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
self.buffer.extend(filtered_batch)
while len(self.buffer) >= self.dp_size * self.minibatch_size:
batches = self.buffer[
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
)
]
batch = bind_batch(batches)
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch uses the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
loss = self.step(i, pbar, **batch)
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
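To make the new buffer bookkeeping concrete, the following is a minimal, self-contained sketch of the effective-to-raw mapping idea on a toy buffer; toy_buffer, the group dictionaries, and the dp/minibatch values are made up for illustration and are not the BaseConsumer code itself.

# Minimal sketch (not the actual BaseConsumer code): each buffer entry is
# [filtered_group_or_None, reward_mean, format_acc_mean, ans_acc_mean, response_len_mean].
toy_buffer = [
    [None, 0.1, 0.0, 0.0, 120.0],       # filtered out, kept only so metrics see the full batch
    [{"id": 1}, 0.8, 1.0, 1.0, 95.0],   # effective group
    [{"id": 2}, 0.5, 1.0, 0.0, 140.0],  # effective group
    [None, 0.2, 0.0, 0.0, 200.0],
    [{"id": 3}, 0.9, 1.0, 1.0, 80.0],
]

# Map the i-th effective (non-None) group to its index in the raw buffer.
effective_to_raw = {}
for raw_idx, entry in enumerate(toy_buffer):
    if entry[0] is not None:
        effective_to_raw[len(effective_to_raw)] = raw_idx

dp_size, minibatch_size = 1, 2  # toy values
if len(effective_to_raw) >= dp_size * minibatch_size:
    # Training consumes the first dp_size * minibatch_size effective groups ...
    train_groups = [toy_buffer[effective_to_raw[i]][0] for i in range(dp_size * minibatch_size)]
    # ... while logging averages over *all* raw groups up to the last consumed effective one.
    last_raw_idx = effective_to_raw[dp_size * minibatch_size - 1]
    raw_slice = toy_buffer[: last_raw_idx + 1]
    reward_mean = sum(entry[1] for entry in raw_slice) / len(raw_slice)
    print(train_groups, f"{reward_mean:.4f}")  # [{'id': 1}, {'id': 2}] 0.4667
    # Everything up to and including the last consumed raw group is then dropped.
    toy_buffer = toy_buffer[last_raw_idx + 1 :]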

View File

@@ -72,12 +72,12 @@ class GRPOConsumer(BaseConsumer):
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
self.accum_count = 0
self.generate_config = generate_config
self.grpo_config = grpo_config
@@ -186,7 +186,11 @@ class GRPOConsumer(BaseConsumer):
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
@@ -430,11 +434,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_acc.add_(format_acc.data)
self.accum_ans_acc.add_(ans_acc.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
@@ -452,21 +452,33 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
raw_batch_reward_mean = sum(self.raw_train_batch_reward) / len(self.raw_train_batch_reward)
raw_batch_format_acc_mean = sum(self.raw_train_batch_format_acc) / len(
self.raw_train_batch_format_acc
)
raw_batch_ans_acc_mean = sum(self.raw_train_batch_ans_acc) / len(self.raw_train_batch_ans_acc)
raw_batch_response_len_mean = sum(self.raw_train_batch_response_len) / len(
self.raw_train_batch_response_len
)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
f"Reward: {raw_batch_reward_mean:.4f}",
f"format Reward: {raw_batch_format_acc_mean:.4f}",
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"metrics/reward": raw_batch_reward_mean,
"metrics/format_acc": raw_batch_format_acc_mean,
"metrics/ans_acc": raw_batch_ans_acc_mean,
"metrics/response_length": raw_batch_response_len_mean,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
@@ -478,12 +490,8 @@ class GRPOConsumer(BaseConsumer):
if self.wandb_run is not None:
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_ans_acc.zero_()
self.accum_format_acc.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
else:
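As a small illustration of what the logging change amounts to: the consumer now reports the mean of the per-group means collected on the raw (pre-filtering) batch, rather than averaging accumulated tensors from the filtered mini-batches only. The per-group values below are hypothetical, and this is a sketch rather than the GRPOConsumer code.

# Sketch: log the mean of per-group reward means computed before filtering.
raw_train_batch_reward = [0.1, 0.8, 0.5, 0.2, 0.9]  # hypothetical per-group reward means
raw_batch_reward_mean = sum(raw_train_batch_reward) / len(raw_train_batch_reward)
print(f"Reward: {raw_batch_reward_mean:.4f}")  # Reward: 0.5000
# The list is cleared after logging so the next logging window starts fresh.
raw_train_batch_reward = []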

View File

@@ -1,4 +1,5 @@
import copy
import os
import uuid
from typing import Any, Dict, Optional
@@ -56,7 +57,7 @@ def launch_distributed(
eval_save_dir: Optional[str] = None,
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
rollout_save_dir: str = "./rollout",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -74,6 +75,10 @@ def launch_distributed(
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
wandb_group_name = str(uuid.uuid4())
rollout_log_file = os.path.join(
rollout_save_dir,
f"{project_name}_run_{wandb_group_name}.jsonl",
)
procs = []
for i in range(num_producers):
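For illustration only, the per-run rollout log path is now derived inside launch_distributed roughly as below; the directory and project name are made-up placeholders, not values from the commit.

import os
import uuid

# Hypothetical values; the real ones come from the CLI arguments and wandb group name.
rollout_save_dir = "./rollouts"
project_name = "my_project"
wandb_group_name = str(uuid.uuid4())

# Each run gets its own JSONL file instead of a single fixed rollout_log.jsonl.
rollout_log_file = os.path.join(rollout_save_dir, f"{project_name}_run_{wandb_group_name}.jsonl")
print(rollout_log_file)  # e.g. ./rollouts/my_project_run_<uuid>.jsonl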

View File

@@ -121,6 +121,34 @@ if __name__ == "__main__":
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-ptp",
"--produce-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@@ -178,7 +206,7 @@ if __name__ == "__main__":
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
tensor_parallel_size=args.produce_tensor_parallel_size,
)
)
generate_config.update(
@@ -228,7 +256,7 @@ if __name__ == "__main__":
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
@@ -247,17 +275,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 2,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp, tp
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
@@ -273,5 +298,5 @@ if __name__ == "__main__":
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
rollout_save_dir=args.rollout_save_dir,
)
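To summarize how the new flags reach the trainer, here is a hedged sketch with made-up values (tp=2, pp=2, zero=1, train_microbatch_size=8); it mirrors the mapping in the script above but is not the script itself.

# Hypothetical argument values; in the real script they come from argparse.
tensor_parallel_size = 2      # --tensor-parallel-size, trainer TP
pipeline_parallel_size = 2    # --pipeline-parallel-size, trainer PP
zero_stage = 1                # --zero-stage, trainer ZeRO stage
train_microbatch_size = 8

plugin_config = {
    "tp_size": tensor_parallel_size,
    "pp_size": pipeline_parallel_size,
    # microbatch size should be train_microbatch_size // pp_size, but at least 1
    "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),
    "zero_stage": zero_stage,
    "max_norm": 1.0,
}
print(plugin_config)  # {'tp_size': 2, 'pp_size': 2, 'microbatch_size': 4, 'zero_stage': 1, 'max_norm': 1.0}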