diff --git a/.gitignore b/.gitignore
index 533450a7c..d48035a5b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 9503d65eb..620ab6ff6 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -36,6 +36,7 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ class BaseConsumer:
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.eval_interval = eval_interval
 
         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -92,8 +94,10 @@ class BaseConsumer:
 
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
-        self.buffer = []
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
+        self.buffer = []
 
         self.recv_cnt = 0
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -110,6 +114,24 @@ class BaseConsumer:
         with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
             for step in pbar:
                 i = 0
+                if self.eval_interval > 0 and step % self.eval_interval == 0:
+                    eval_statistics = None
+                    for r in range(self.num_producers):
+                        print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
+                        local_eval_result = ray_broadcast_tensor_dict(
+                            None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
+                        )
+                        if eval_statistics is None:
+                            eval_statistics = local_eval_result
+                        else:
+                            eval_statistics = {
+                                k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
+                            }
+                    eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                    if dist.get_rank() == 0:
+                        if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
+                            self.wandb_run.log(eval_statistics, step=self.global_step)
+                        print(f"Eval statistics: {eval_statistics}")
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -192,6 +214,7 @@ class SimpleConsumer(BaseConsumer):
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         super().__init__(
             num_producers,
@@ -206,6 +229,9 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
+            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
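The consumer-side reduction above is a plain element-wise sum: each producer reports, per eval task, a length-2 tensor holding [num_correct, num_total], and the consumer accumulates these across producers before turning counts into an accuracy. A minimal sketch of that reduction with hypothetical numbers (the task name and counts are made up):

    import torch

    # Each producer broadcasts {task_name: tensor([num_correct, num_total])}.
    producer_reports = [
        {"math_eval": torch.tensor([12.0, 16.0])},  # producer 0: 12/16 correct
        {"math_eval": torch.tensor([10.0, 16.0])},  # producer 1: 10/16 correct
    ]

    totals = None
    for report in producer_reports:
        if totals is None:
            totals = report
        else:
            # Element-wise sum per task, as in BaseConsumer above.
            totals = {k: totals[k] + report[k] for k in totals}

    # Counts -> accuracy, the value that gets logged to wandb.
    accuracy = {k: (v[0] / v[1]).item() for k, v in totals.items()}
    print(accuracy)  # {'math_eval': 0.6875}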
"sample_level") == "token_level": @@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer): minibatch_size, save_interval=save_interval, save_dir=save_dir, + eval_interval=eval_interval, ) path = model_config.pop("path") self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer): self.policy_model._force_wait_all_gather() model = self.policy_model.unwrap() state_dict = model.state_dict() + state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device) return state_dict diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 8ff4b15bf..64dfe5cfd 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend): generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update({"n": num_generations}) - self.generate_config = SamplingParams(**generate_config) + self.generate_config = generate_config + self.sample_params = SamplingParams(**generate_config) self.model_config = model_config self.tokenizer = tokenizer self.num_generations = num_generations @@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend): micro_batch_input_ids_no_padding = [ micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size) ] + sample_params = kwargs.get("sample_params", self.sample_params) outputs = self.llm.generate( - prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False + prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False ) out_tokens = [] out_len = [] @@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend): "response_idx": response_idx, } - data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()} if "gt_answer" in kwargs: # repeat gt_answer for each prompt. 
- data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1) + data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1) data = {k: v.to(get_current_device()) for k, v in data.items()} return data diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index a346d1d4f..d1b6158e5 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -34,7 +34,7 @@ def launch_distributed( inference_microbatch_size: int, train_batch_size: int, train_minibatch_size: int, - dataset_config: Dict[str, Any], + train_dataset_config: Dict[str, Any], dataloaders_config: Dict[str, Any], inference_model_config: Dict[str, Any], generate_config: Dict[str, Any], @@ -50,6 +50,9 @@ def launch_distributed( project_name: Optional[str] = None, save_interval: int = 100, save_dir: str = "./model", + eval_dataset_config: Optional[Dict[str, Any]] = None, + eval_interval: int = 100, + eval_save_dir: Optional[str] = None, ): if core_algo not in ALGO_MAP: @@ -60,9 +63,9 @@ def launch_distributed( train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config) assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 - dataset_path = dataset_config["path"] + dataset_path = train_dataset_config["path"] num_samples = get_jsonl_size_fast(dataset_path) - global_inference_batch_size = inference_batch_size * num_producers + global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer num_update_per_episode = num_samples // global_inference_batch_size num_recv_per_update = inference_batch_size // inference_microbatch_size @@ -74,7 +77,7 @@ def launch_distributed( num_consumer_procs=num_consumer_procs, num_episodes=num_episodes, batch_size=inference_batch_size, - dataset_config=dataset_config, + train_dataset_config=train_dataset_config, dataloaders_config=dataloaders_config, model_config=inference_model_config, generate_config=generate_config, @@ -83,6 +86,10 @@ def launch_distributed( backend=inference_backend, num_generations=num_generations, consumer_plugin_config=plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + evaluation_function_type=grpo_config["reward_fn_type"], + eval_save_dir=eval_save_dir, ) procs.append(producer) generate_config_consumer = copy.deepcopy(generate_config) @@ -111,6 +118,7 @@ def launch_distributed( project_name=project_name, save_interval=save_interval, save_dir=save_dir, + eval_interval=eval_interval, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index ffe3a4428..483ad74b3 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -1,9 +1,13 @@ +import copy +import os from typing import Any, Dict, Optional import ray import ray.util.collective as cc import torch +import tqdm from coati.dataset.loader import RawConversationDataset +from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer @@ -11,7 +15,12 @@ from colossalai.utils import get_current_device from .comm import ray_broadcast_tensor_dict from .inference_backend import BACKEND_MAP -from 
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a346d1d4f..d1b6158e5 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,6 +50,9 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
 ):
     if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers
+    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
@@ -74,7 +77,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +86,10 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@ def launch_distributed(
             project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
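For reference, the scheduling arithmetic above, worked through with hypothetical sizes (note the TODO: producer-side tensor parallelism is not accounted for here):

    # Hypothetical numbers, only to make the divisions concrete.
    inference_batch_size = 32
    inference_microbatch_size = 8
    num_producers = 2
    num_samples = 1280  # lines in the training JSONL

    global_inference_batch_size = inference_batch_size * num_producers  # 64
    num_update_per_episode = num_samples // global_inference_batch_size  # 20
    num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4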
group_name="sync_model") + cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}") def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -96,29 +147,64 @@ class BaseProducer: raise NotImplementedError def loop(self) -> None: - num_update_per_episode = len(self.dataloader) // self.num_microbatches + num_update_per_episode = len(self.train_dataloader) // self.num_microbatches num_valid_microbatches = num_update_per_episode * self.num_microbatches print( - f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}" + f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}" ) for episode in range(self.num_episodes): - self.dataloader.sampler.set_epoch(episode) - for i, batch in enumerate(self.dataloader): + self.train_dataloader.sampler.set_epoch(episode) + for i, batch in enumerate(self.train_dataloader): if i >= num_valid_microbatches: break + if self.eval_interval > 0 and self.eval_dataset_config is not None: + if i % self.eval_interval == 0: + eval_statistics = {} + for eval_task_name in self.eval_dataloaders: + print( + f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}" + ) + eval_results = [] + eval_statistics[eval_task_name] = torch.zeros(2, device=self.device) + for eval_batch in tqdm.tqdm( + self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0 + ): + eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params) + eval_results = eval_results + [ + self.evaluation_function( + eval_outputs["input_ids"][m][n], + eval_outputs["gt_answer"][m][n], + eval_outputs["response_idx"][m][n], + tokenizer=self.tokenizer, + eval_mode=True, + ) + for m in range(eval_outputs["input_ids"].size(0)) + for n in range(eval_outputs["input_ids"].size(1)) + ] + eval_statistics[eval_task_name][0] += len( + [res for res in eval_results if res["ans_valid"] == 1] + ) + eval_statistics[eval_task_name][1] += len(eval_results) + # save eval results + result_file_name = os.path.join( + self.eval_save_dir, + f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl", + ) + # delete the file if it exists + safe_write_jsonl(result_file_name, eval_results) + print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}") + ray_broadcast_tensor_dict( + eval_statistics, + src=0, + device=self.device, + group_name=f"sync_eval_statistics_{self.producer_idx}", + ) outputs = self.rollout(**batch) print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( - [ - ( - self.model.generate_config["temperature"] - if isinstance(self.model.generate_config.temperature, dict) - else self.model.generate_config.temperature - ) - ] - * outputs["input_ids"].size(0) + [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) outputs = pre_send(outputs) ray_broadcast_tensor_dict( @@ -150,6 +236,8 @@ class BaseProducer: state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) + if "consumer_global_step" in state_dict: + self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() @@ -159,7 
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index b68c1a92f..14d340dc4 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_res
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
@@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
 
         reward = reward + length_reward
 
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-
-
-def gsm8k_reward_fn(input_ids, **kwargs):
-    gt_answer = kwargs["gt_answer"]
-    tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
-    reward = torch.tensor(0.0).to(input_ids.device)
-    if gt_answer is None:
-        return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    final_answer, processed_str = extract_solution(decoded_final_answer)
-    is_valid = True
-    try:
-        int(final_answer.strip())
-    except Exception:
-        is_valid = False
-
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not is_valid or not format_valid:
-        return reward
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
-        reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 9.0
-        return reward
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
 
 
 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return reward
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             reward += acc_score
 
         reward = reward + length_reward
-
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
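After this change both reward functions carry a dual contract: a 3-element tensor during training and a JSON-serializable record during evaluation. Illustrative values for the two shapes (the strings below are made up):

    import torch

    # eval_mode=False: consumed by the GRPO loss and training metrics.
    train_result = torch.tensor([10.0, 1.0, 1.0])  # [reward, format_acc, ans_acc]

    # eval_mode=True: one record per sample, written out via safe_write_jsonl.
    eval_result = {
        "prompt": "What is 2 + 2?",
        "prediction": "<think>2 + 2 = 4</think><answer>4</answer>",
        "gold": "4",
        "parsed": "4",
        "format_valid": 1.0,
        "ans_valid": 1.0,
    }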
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 5f7879669..fb2a0dfd3 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,6 +1,9 @@
+import json
+import os
 from typing import Any, Dict, List
 
 import torch
+from filelock import FileLock
 
 from colossalai.shardformer.layer.loss import dist_log_prob
 
@@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     tensor = tensor * mask
     return tensor.sum(dim=dim)
+
+
+def safe_write_jsonl(file_path, data):
+    with FileLock(file_path + ".lock"):
+        # Ensure the parent directory exists
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        with open(file_path, "a", encoding="utf8") as f:
+            for entry in data:
+                json_line = json.dumps(entry, ensure_ascii=False)
+                f.write(json_line + "\n")
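A usage sketch for the helper (the path and record are placeholders): several producer processes may target the same per-task file, and the FileLock serializes the appends, so records from all ranks end up interleaved but intact. Requires the filelock package and this repo on the path.

    from coati.distributed.utils import safe_write_jsonl

    records = [
        {"prompt": "1+1?", "prediction": "2", "gold": "2", "parsed": "2",
         "format_valid": 1.0, "ans_valid": 1.0},
    ]
    # Appends one JSON object per line; creates ./eval/demo/ if needed.
    safe_write_jsonl("./eval/demo/math_eval_episode_0_step_0.jsonl", records)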
""" tensor = tensor * mask return tensor.sum(dim=dim) + + +def safe_write_jsonl(file_path, data): + with FileLock(file_path + ".lock"): + # Ensure file exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "a", encoding="utf8") as f: + for entry in data: + json_line = json.dumps(entry, ensure_ascii=False) + f.write(json_line + "\n") diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 788e60c2e..b2fafd42e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -1,4 +1,5 @@ import argparse +import json import os import ray @@ -8,7 +9,16 @@ from coati.distributed.launch import launch_distributed if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") - parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") + parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl") + parser.add_argument( + "-ed", + "--eval-dataset", + type=str, + default='{"eval task name":"data_eval.jsonl"}', + help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ + For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \ + The key is the task name, and the value is the path to the jsonl file", + ) parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") @@ -94,11 +104,14 @@ if __name__ == "__main__": choices=["think_answer_tags", "boxed"], help="Reward type for GRPO.", ) + parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.") # Logging/Checkpointing parameters parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.") - + parser.add_argument( + "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results." + ) args = parser.parse_args() if args.train_minibatch_size is None: @@ -208,7 +221,7 @@ if __name__ == "__main__": inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, train_minibatch_size=args.train_minibatch_size, - dataset_config={ + train_dataset_config={ "path": args.dataset, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt, @@ -238,4 +251,10 @@ if __name__ == "__main__": project_name=args.project, save_interval=args.save_interval, save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")), + eval_dataset_config={ + k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt} + for k, v in json.loads(args.eval_dataset).items() + }, + eval_interval=args.eval_interval, + eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), )