support AIME validation

ElcarimQAQ 2025-04-18 13:59:34 +08:00
parent 9467c10690
commit ad56d16c1d
3 changed files with 144 additions and 2 deletions

View File

@@ -437,3 +437,43 @@ class RawConversationDataset(Dataset):
            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
            self.tokenized_texts[index] = dict(tokens)
        return self.tokenized_texts[index]
class AIMEDataset(Dataset):
    """
    AIME dataset.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
        self.tokenizer = tokenizer
        self.raw_texts = []
        with jsonlines.open(input_file) as f:
            for line in f:
                self.raw_texts.append(line)
        self.tokenized_texts = [None] * len(self.raw_texts)
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.raw_texts)

    def __getitem__(self, index: int):
        if self.tokenized_texts[index] is None:
            message = self.raw_texts[index]
            gt_answer = self.tokenizer.encode(
                message["answer"], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"
            )

            def make_conv_hf(question):
                msg = [{"role": "user", "content": question}]
                return msg

            message = make_conv_hf(message["question"])
            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
            self.tokenized_texts[index] = dict(tokens)
            # gt_answer is [1, max_length]; stacked batches become [B, 1, max_length].
            self.tokenized_texts[index]["gt_answer"] = gt_answer.squeeze(1)
        return self.tokenized_texts[index]


if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("/home/share/data/model/Qwen2.5-3B")
    dataset = AIMEDataset(
        tokenizer, "/home/yanglibing/workspace/PRIME/eval/data/AI-MO/aimo-validation-aime/aimo-validation-aime.jsonl", 512
    )
    print(dataset[0])
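For context, AIMEDataset only reads the "question" and "answer" keys of each JSONL record. A minimal sketch of a compatible input file, using made-up records rather than the real aimo-validation-aime data:

# Sketch of the record format AIMEDataset consumes. The two records are
# hypothetical stand-ins, not entries from aimo-validation-aime.jsonl.
import jsonlines

records = [
    {"question": "What is 2 + 2?", "answer": "4"},
    {"question": "Compute 3 * 7.", "answer": "21"},
]
with jsonlines.open("dummy_aime.jsonl", mode="w") as writer:
    writer.write_all(records)

# AIMEDataset(tokenizer, "dummy_aime.jsonl", max_length=512)[0] then yields
# the chat-template token fields plus a padded "gt_answer" tensor.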

View File

@@ -1,12 +1,21 @@
from collections import defaultdict
import os
from typing import Any, Dict, Optional

import numpy as np
import ray
import ray.util.collective as cc
import torch
import wandb
from coati.dataset.loader import AIMEDataset, RawConversationDataset
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.models.utils import read_jsonl_file
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
@@ -90,17 +99,19 @@ class BaseProducer:
    def loop(self) -> None:
        num_update_per_episode = len(self.dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.dataloader):
                valid_metrics = self.validate()
                if i >= num_valid_microbatches:
                    break
                outputs = self.rollout(**batch)
                outputs.update(valid_metrics)
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs["temperature"] = torch.tensor(
                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
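For reference, the two counters at the top of loop() drop any trailing batches that cannot fill a whole optimizer update. A small worked example of that arithmetic, with dummy sizes not taken from this commit:

# Worked example of the microbatch accounting in loop() (dummy sizes).
dataloader_len = 10   # batches per epoch
num_microbatches = 4  # microbatches per optimizer update

num_update_per_episode = dataloader_len // num_microbatches         # 2
num_valid_microbatches = num_update_per_episode * num_microbatches  # 8
# The `if i >= num_valid_microbatches: break` guard then skips batches
# 8 and 9, so each episode consumes exactly two full updates of data.
print(num_update_per_episode, num_valid_microbatches)  # 2 8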
@@ -176,3 +187,92 @@ class SimpleProducer(BaseProducer):
    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
    def validate(self):
        self.val_dataset = AIMEDataset(
            tokenizer=self.tokenizer,
            input_file="/home/yanglibing/workspace/PRIME/eval/data/AI-MO/aimo-validation-aime/aimo-validation-aime.jsonl",
            max_length=300,
        )

        # Initialize the verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        def collate_fn(data_list: list[dict]) -> dict:
            # Stack tensor fields along a new batch dimension; collect
            # everything else into numpy object arrays.
            tensors = defaultdict(list)
            non_tensors = defaultdict(list)
            for data in data_list:
                for key, val in data.items():
                    if isinstance(val, torch.Tensor):
                        tensors[key].append(val)
                    else:
                        non_tensors[key].append(val)
            for key, val in tensors.items():
                tensors[key] = torch.stack(val, dim=0)
            for key, val in non_tensors.items():
                non_tensors[key] = np.array(val, dtype=object)
            return {**tensors, **non_tensors}

        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=64,
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn,
        )

        all_rewards = torch.tensor([], device=self.device)
        all_formats = torch.tensor([], device=self.device)
        all_accs = torch.tensor([], device=self.device)
        for test_batch in self.val_dataloader:
            # test_batch["input_ids"]: [batch_size, max_length]
            # test_batch["gt_answer"]: [batch_size, 1, max_length]
            test_output = self.rollout(**test_batch)
            # test_output["response_idx"]: [batch_size, num_generations, 2]
            num_generations = test_output["response_idx"].size(1)
            print("num_generations", num_generations)
            # Flatten the generation dimension so each completion is scored independently.
            data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()}
            reward_group = self.reward_model(
                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
            )
            # Each entry is (total reward, format reward, accuracy reward).
            rewards = torch.stack([x[0] for x in reward_group])
            format_rewards = torch.stack([x[1] for x in reward_group])
            acc_rewards = torch.stack([x[2] for x in reward_group])
            all_rewards = torch.cat([all_rewards, rewards])
            all_formats = torch.cat([all_formats, format_rewards])
            all_accs = torch.cat([all_accs, acc_rewards])

        valid_metrics = {
            "avg_reward": torch.mean(all_rewards).unsqueeze(0),
            "avg_format": torch.mean(all_formats).unsqueeze(0),
            "avg_acc": torch.mean(all_accs).unsqueeze(0),
        }
        print(
            f"[P{self.producer_idx}] Validation metrics: "
            f"reward={valid_metrics['avg_reward'].item():.4f}, "
            f"format={valid_metrics['avg_format'].item():.4f}, "
            f"acc={valid_metrics['avg_acc'].item():.4f}"
        )
        return valid_metrics
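The `data = {k: v.view(-1, v.size(-1)) for ...}` step above collapses the per-prompt generation dimension so the reward model scores each sampled completion on its own. A standalone sketch of the shape arithmetic, with dummy sizes that are illustrative only:

# Shape sketch: flattening [batch, num_generations, seq_len] into
# [batch * num_generations, seq_len] before reward scoring (dummy sizes).
import torch

batch_size, num_generations, seq_len = 2, 8, 300
test_output = {
    "input_ids": torch.zeros(batch_size, num_generations, seq_len, dtype=torch.long),
    "response_idx": torch.zeros(batch_size, num_generations, 2, dtype=torch.long),
}

data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()}
print(data["input_ids"].shape)     # torch.Size([16, 300])
print(data["response_idx"].shape)  # torch.Size([16, 2])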

View File

@@ -0,0 +1,2 @@
export NCCL_BLOCKING_WAIT=1
CUDA_VISIBLE_DEVICES=4,5,6,7 python rl_example.py --dataset /home/share/data/dataset/math_competition_train_short.jsonl --model /home/share/data/model/Qwen2.5-3B -t 1 -i 2 -b vllm