diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 4518fd71f..1b15bfef0 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -437,3 +437,56 @@ class RawConversationDataset(Dataset):
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]
+
+class AIMEDataset(Dataset):
+    """
+    Dataset for AIME-style evaluation data.
+
+    Each line of ``input_file`` is a JSON object with a ``question`` field
+    (the user prompt) and an ``answer`` field (the ground-truth answer used
+    to compute verifiable rewards).
+    """
+
+ def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+ self.tokenizer = tokenizer
+ self.raw_texts = []
+ with jsonlines.open(input_file) as f:
+ for line in f:
+ self.raw_texts.append(line)
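+        # Tokenize lazily: cache each sample's tokens on first access.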
+ self.tokenized_texts = [None] * len(self.raw_texts)
+ self.max_length = max_length
+
+ def __len__(self) -> int:
+ return len(self.raw_texts)
+
+ def __getitem__(self, index: int):
+ if self.tokenized_texts[index] is None:
+ message = self.raw_texts[index]
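+            # Encode the ground-truth answer padded/truncated to a fixed
+            # length so samples can be batched; encode returns [1, max_length].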
+            gt_answer = self.tokenizer.encode(
+                message["answer"],
+                padding="max_length",
+                truncation=True,
+                max_length=self.max_length,
+                return_tensors="pt",
+            )
+
+            # Wrap the raw question into a single-turn chat message.
+            chat_message = [{"role": "user", "content": message["question"]}]
+            tokens = apply_chat_template_and_mask(self.tokenizer, chat_message, self.max_length)
+ self.tokenized_texts[index] = dict(tokens)
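+            # squeeze(1) is a no-op on [1, max_length]; keeping the leading
+            # singleton dim makes batches collate to [B, 1, max_length].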
+ self.tokenized_texts[index]["gt_answer"] = gt_answer.squeeze(1)
+ return self.tokenized_texts[index]
+
+if __name__ == "__main__":
+    # Local smoke test; the paths below are specific to the author's machine.
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("/home/share/data/model/Qwen2.5-3B")
+    dataset = AIMEDataset(
+        tokenizer,
+        "/home/yanglibing/workspace/PRIME/eval/data/AI-MO/aimo-validation-aime/aimo-validation-aime.jsonl",
+        512,
+    )
+    print(dataset[0])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 2c6a24a36..8408693b5 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,12 +1,16 @@
+from collections import defaultdict
 from typing import Any, Dict, Optional
 
+import numpy as np
 import ray
 import ray.util.collective as cc
 import torch
-from coati.dataset.loader import RawConversationDataset
+from coati.dataset.loader import AIMEDataset, RawConversationDataset
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
@@ -90,17 +99,20 @@ class BaseProducer:
def loop(self) -> None:
num_update_per_episode = len(self.dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
-
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
)
for episode in range(self.num_episodes):
self.dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.dataloader):
                 if i >= num_valid_microbatches:
                     break
+                # Validate alongside each rollout step and send the metrics to
+                # the consumer together with the rollout tensors.
+                valid_metrics = self.validate()
outputs = self.rollout(**batch)
-
+ outputs.update(valid_metrics)
+
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
@@ -176,3 +187,93 @@ class SimpleProducer(BaseProducer):
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
+
+    def validate(self):
+        """Score the model on the AIME validation set and return mean rewards."""
+        # Build the validation dataset, reward model, and dataloader once on
+        # the first call and cache them on the producer.
+        if not hasattr(self, "val_dataloader"):
+            # NOTE: hardcoded local path to the AIMO validation split.
+            self.val_dataset = AIMEDataset(
+                tokenizer=self.tokenizer,
+                input_file="/home/yanglibing/workspace/PRIME/eval/data/AI-MO/aimo-validation-aime/aimo-validation-aime.jsonl",
+                max_length=300,
+            )
+            # Initialize the verifiable reward with the expected response format tags.
+            response_format_tags = {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            self.reward_model = VerifiableReward(
+                reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+            )
+
+            def collate_fn(data_list: list[dict]) -> dict:
+                # Stack tensor fields into batched tensors; keep everything else
+                # as object arrays so non-tensor fields survive collation.
+                tensors = defaultdict(list)
+                non_tensors = defaultdict(list)
+                for data in data_list:
+                    for key, val in data.items():
+                        if isinstance(val, torch.Tensor):
+                            tensors[key].append(val)
+                        else:
+                            non_tensors[key].append(val)
+                for key, val in tensors.items():
+                    tensors[key] = torch.stack(val, dim=0)
+                for key, val in non_tensors.items():
+                    non_tensors[key] = np.array(val, dtype=object)
+                return {**tensors, **non_tensors}
+
+            # Evaluate the full split deterministically: no shuffling, no
+            # dropped remainder batch.
+            self.val_dataloader = DataLoader(
+                dataset=self.val_dataset,
+                batch_size=64,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn,
+            )
+
+        all_rewards = torch.tensor([], device=self.device)
+        all_formats = torch.tensor([], device=self.device)
+        all_accs = torch.tensor([], device=self.device)
+
+        for test_batch in self.val_dataloader:
+            # test_batch["input_ids"]: [B, 300]; test_batch["gt_answer"]: [B, 1, 300]
+            test_output = self.rollout(**test_batch)
+            # test_output["response_idx"]: [B, num_generations, 2]
+            # Flatten the generation dimension into the batch dimension.
+            data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()}
+            reward_group = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
+
+            # Each entry of reward_group is (total_reward, format_reward, accuracy_reward).
+            rewards = torch.stack([x[0] for x in reward_group])
+            format_rewards = torch.stack([x[1] for x in reward_group])
+            acc_rewards = torch.stack([x[2] for x in reward_group])
+
+            all_rewards = torch.cat([all_rewards, rewards])
+            all_formats = torch.cat([all_formats, format_rewards])
+            all_accs = torch.cat([all_accs, acc_rewards])
+
+        avg_reward = all_rewards.mean()
+        avg_format = all_formats.mean()
+        avg_acc = all_accs.mean()
+
+        # Pack the scalars as 1-element tensors so they can travel in the
+        # same tensor dict as the rollout outputs.
+        valid_metrics = {
+            "avg_reward": avg_reward.unsqueeze(0),
+            "avg_format": avg_format.unsqueeze(0),
+            "avg_acc": avg_acc.unsqueeze(0),
+        }
+        print(
+            f"[P{self.producer_idx}] Validation metrics: "
+            f"reward={valid_metrics['avg_reward'].item():.4f}, "
+            f"format={valid_metrics['avg_format'].item():.4f}, "
+            f"acc={valid_metrics['avg_acc'].item():.4f}"
+        )
+        return valid_metrics
diff --git a/applications/ColossalChat/run.sh b/applications/ColossalChat/run.sh
new file mode 100644
index 000000000..68ec2976c
--- /dev/null
+++ b/applications/ColossalChat/run.sh
@@ -0,0 +1,7 @@
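+#!/bin/bash
+# Make NCCL collectives blocking so communication failures surface as errors
+# instead of silent hangs.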
+export NCCL_BLOCKING_WAIT=1
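+# Run on GPUs 4-7; -t sets the trainer count, -i the inferencer count, and
+# -b selects the vLLM generation backend.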
+CUDA_VISIBLE_DEVICES=4,5,6,7 python rl_example.py --dataset /home/share/data/dataset/math_competition_train_short.jsonl --model /home/share/data/model/Qwen2.5-3B -t 1 -i 2 -b vllm