diff --git a/.gitignore b/.gitignore
index 533450a7c..d48035a5b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index b07f8d7a1..816cab50a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -93,7 +93,6 @@ class BaseConsumer:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
 
         self.buffer = []
-
         self.recv_cnt = 0
 
     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -209,6 +208,8 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 72e54b0de..eae4ff54e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -33,12 +33,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -84,6 +85,9 @@ class GRPOConsumer(BaseConsumer):
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name
 
         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -134,7 +138,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -145,13 +148,16 @@ class GRPOConsumer(BaseConsumer):
 
     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
- name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name) + self.wandb_run = wandb.init( + project=self.project_name, + sync_tensorboard=False, + dir="./wandb", + name=self.run_name, + group=self.wandb_group_name, + ) self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler @@ -265,7 +271,6 @@ class GRPOConsumer(BaseConsumer): total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() self.total_sample_count += total_samples.item() - pbar.set_postfix( { "Global Step": self.global_step, @@ -506,8 +511,8 @@ class GRPOConsumer(BaseConsumer): } if self.policy_loss_fn.beta > 0: metrics["train/kl"] = self.accum_kl.item() / self.accum_count - - self.wandb_run.log(metrics) + if self.wandb_run is not None: + self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_reward.zero_() self.accum_ans_acc.zero_() @@ -521,7 +526,6 @@ class GRPOConsumer(BaseConsumer): # All gather excessive prompts index across DP ranks. excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx] excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin) - return loss_scalar, excessive_prompts_idx else: return None, excessive_prompts_idx @@ -530,4 +534,5 @@ class GRPOConsumer(BaseConsumer): self.policy_model._force_wait_all_gather() model = self.policy_model.unwrap() state_dict = model.state_dict() + state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device) return state_dict diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 8ff4b15bf..ce2a2f28c 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend): generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update({"n": num_generations}) - self.generate_config = SamplingParams(**generate_config) + self.generate_config = generate_config + self.sample_params = SamplingParams(**generate_config) self.model_config = model_config self.tokenizer = tokenizer self.num_generations = num_generations @@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend): micro_batch_input_ids_no_padding = [ micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size) ] + sample_params = kwargs.get("sample_params", self.sample_params) outputs = self.llm.generate( - prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False + prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False ) out_tokens = [] out_len = [] @@ -236,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): log_probs.append(p) # pad them - max_len = self.generate_config.max_tokens + max_len = self.sample_params.max_tokens action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype) for i, new_token_ids in enumerate(out_tokens): @@ -266,11 +268,11 @@ class 
             "response_idx": response_idx,
         }
 
-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
 
         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
 
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a346d1d4f..e6367d2d1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional
 
 import ray
@@ -34,7 +35,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,8 +51,11 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
-
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:
@@ -60,12 +64,15 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -74,7 +81,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +90,14 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -108,9 +123,11 @@ def launch_distributed(
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
-            project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index ffe3a4428..0d91f43f1 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,9 +1,16 @@
+import copy
+import os
 from typing import Any, Dict, Optional
 
 import ray
 import ray.util.collective as cc
 import torch
+import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
@@ -11,7 +18,12 @@
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send
+from .utils import pre_send, safe_append_to_jsonl_file
+
+try:
+    from vllm import SamplingParams
+except ImportError:
+    LLM = None
 
 class BaseProducer:
@@ -22,7 +34,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        dataset_config: Dict[str, Any],
+        train_dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -30,6 +42,14 @@ class BaseProducer:
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -39,11 +59,31 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.latest_eval_step = -1
 
-        self.dataset_config = dataset_config
+        self.train_dataset_config = train_dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
+        self.consumer_plugin_config = consumer_plugin_config
+        self.eval_interval = eval_interval
+        self.eval_save_dir = eval_save_dir
+        self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=False,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )
+
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
+            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
 
         # init tokenizer
         if tokenizer_config is None:
@@ -55,13 +95,13 @@ class BaseProducer:
         self.tokenizer.padding_side = "left"
 
         # init dataloader
-        dataset_path = dataset_config.pop("path")
-        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
-        self.dataloader = DataLoader(
-            self.dataset,
+        train_dataset_path = train_dataset_config.pop("path")
+        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+        self.train_dataloader = DataLoader(
+            self.train_dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.dataset,
+                self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
@@ -71,6 +111,36 @@ class BaseProducer:
             num_workers=4,
             drop_last=True,
         )
+
+        self.eval_dataset_config = eval_dataset_config
+        if self.eval_dataset_config is not None:
+            self.eval_dataloaders = {}
+            for eval_task_name in self.eval_dataset_config:
+                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
+                eval_dataset = RawConversationDataset(
+                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+                )
+                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+                self.eval_dataloaders[eval_task_name] = DataLoader(
+                    eval_dataset,
+                    batch_size=microbatch_size,
+                    sampler=DistributedSampler(
+                        eval_dataset,
+                        num_replicas=num_producers,
+                        rank=producer_idx,
+                        shuffle=False,
+                        drop_last=False,
+                        seed=42,
+                    ),
+                )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+        else:
+            raise ValueError("eval_dataset_config is not defined")
         self.device = get_current_device()
 
         # init backend
@@ -82,6 +152,12 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
@@ -96,29 +172,74 @@ class BaseProducer:
         raise NotImplementedError
 
     def loop(self) -> None:
-        num_update_per_episode = len(self.dataloader) // self.num_microbatches
+        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
         print(
-            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
+            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
         )
         for episode in range(self.num_episodes):
-            self.dataloader.sampler.set_epoch(episode)
-            for i, batch in enumerate(self.dataloader):
+            self.train_dataloader.sampler.set_epoch(episode)
+            for i, batch in enumerate(self.train_dataloader):
                 if i >= num_valid_microbatches:
                     break
+                if self.eval_interval > 0 and self.eval_dataset_config is not None:
+                    if (
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
+                        to_log_msg = {}
+                        self.eval_mode = True
+                        for eval_task_name in self.eval_dataloaders:
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
+                                )
+                            eval_results = []
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
+                            for eval_batch in tqdm.tqdm(
+                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+                            ):
+                                eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
+                                eval_results = eval_results + [
+                                    self.evaluation_function(
+                                        eval_outputs["input_ids"][m][n],
+                                        eval_outputs["gt_answer"][m][n],
+                                        eval_outputs["response_idx"][m][n],
+                                        tokenizer=self.tokenizer,
+                                        eval_mode=True,
+                                    )
+                                    for m in range(eval_outputs["input_ids"].size(0))
+                                    for n in range(eval_outputs["input_ids"].size(1))
+                                ]
+                                eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                                eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
+                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
+                            # save eval results
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
+                            )
+
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
 
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [
-                        (
-                            self.model.generate_config["temperature"]
-                            if isinstance(self.model.generate_config.temperature, dict)
-                            else self.model.generate_config.temperature
-                        )
-                    ]
-                    * outputs["input_ids"].size(0)
+                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
@@ -142,6 +263,8 @@ class BaseProducer:
                             state_dict = ray_broadcast_tensor_dict(
                                 None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                             )
+                            if "consumer_global_step" in state_dict:
+                                self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                            self.load_state_dict(state_dict)
                     else:
                         print(
@@ -150,6 +273,8 @@ class BaseProducer:
                         state_dict = ray_broadcast_tensor_dict(
                             None, self.num_producers, device=self.device, group_name="sync_model"
                         )
+                        if "consumer_global_step" in state_dict:
+                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                        self.load_state_dict(state_dict)
                    del state_dict
                    torch.cuda.empty_cache()
@@ -159,13 +284,12 @@ class BaseProducer:
                        self.model.llm.wake_up()
                # linear annealing for 1 episode, temperature from initial to 0.9
                if episode <= 0:
-                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
BACKEND_MAP["vllm"]): + self.model.sample_params.temperature = (1 - ratio) * self.generate_config[ "temperature" ] + ratio * 0.9 @@ -179,7 +303,7 @@ class SimpleProducer(BaseProducer): num_consumer_procs, num_episodes, batch_size, - dataset_config, + train_dataset_config, dataloaders_config, model_config, generate_config, @@ -188,6 +312,14 @@ class SimpleProducer(BaseProducer): backend="transformers", num_generations: int = 8, consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + evaluation_function_type="think_answer_tags", + eval_save_dir: str = "./eval", + eval_generation_config={}, + project_name: str = None, + run_name: str = None, + wandb_group_name: str = None, ): super().__init__( producer_idx, @@ -195,7 +327,7 @@ class SimpleProducer(BaseProducer): num_consumer_procs, num_episodes, batch_size, - dataset_config, + train_dataset_config, dataloaders_config, model_config, generate_config, @@ -203,15 +335,43 @@ class SimpleProducer(BaseProducer): microbatch_size, backend, consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + evaluation_function_type=evaluation_function_type, + eval_save_dir=eval_save_dir, + project_name=project_name, + run_name=run_name, + wandb_group_name=wandb_group_name, ) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_generation_config.update(eval_generation_config) + self.eval_sample_params = SamplingParams(**self.eval_generation_config) @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): rollouts = self.model.generate(input_ids, attention_mask, **kwargs) - if self.producer_idx == 1: - print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) - + if self.producer_idx == 0 and not self.eval_mode: + wandb_rollout_data = self.wandb_rollout_data + [ + [ + str(self.consumer_global_step), + str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)), + ] + ] + if ( + self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval + or self.latest_rollout_log_step == -1 + ): + self.wandb_rollout_data = wandb_rollout_data + self.latest_rollout_log_step = self.consumer_global_step + self.wandb_run.log( + { + "rollout/rollout_examples": wandb.Table( + columns=["train_step", "rollout_examples"], data=wandb_rollout_data + ) + } + ) return rollouts def load_state_dict(self, state_dict): diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index b68c1a92f..14d340dc4 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_res def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] + eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) acc_score = 10.0 reward = torch.tensor(0.0) @@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): reward = reward + length_reward - return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) - - -def 
-    gt_answer = kwargs["gt_answer"]
-    tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
-    reward = torch.tensor(0.0).to(input_ids.device)
-    if gt_answer is None:
-        return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    final_answer, processed_str = extract_solution(decoded_final_answer)
-    is_valid = True
-    try:
-        int(final_answer.strip())
-    except Exception:
-        is_valid = False
-
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not is_valid or not format_valid:
-        return reward
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
-        reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 9.0
-        return reward
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
 
 
 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return reward
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             reward += acc_score
 
         reward = reward + length_reward
-
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 5f7879669..a40ebbcfb 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -1,6 +1,9 @@
+import json
+import os
 from typing import Any, Dict, List
 
 import torch
+from filelock import FileLock
 
 from colossalai.shardformer.layer.loss import dist_log_prob
 
@@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
""" tensor = tensor * mask return tensor.sum(dim=dim) + + +def safe_append_to_jsonl_file(file_path, data): + with FileLock(file_path + ".lock"): + # Ensure file exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "a", encoding="utf8") as f: + for entry in data: + json_line = json.dumps(entry, ensure_ascii=False) + f.write(json_line + "\n") diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index f5609c890..8d1f25e74 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -1,4 +1,5 @@ import argparse +import json import os import ray @@ -9,7 +10,16 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") - parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.") + parser.add_argument( + "-ed", + "--eval-dataset", + type=str, + default='{"eval task name":"data_eval.jsonl"}', + help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ + For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \ + The key is the task name, and the value is the path to the jsonl file", + ) + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") # Distributed training parameters @@ -94,11 +104,20 @@ if __name__ == "__main__": choices=["think_answer_tags", "boxed"], help="Reward type for GRPO.", ) + parser.add_argument( + "-ei", + "--eval-interval", + type=int, + default=100, + help="Interval for evaluation. Evaluate every ei training steps.", + ) # Logging/Checkpointing parameters parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.") - + parser.add_argument( + "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results." 
+    )
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:
@@ -148,6 +167,7 @@ if __name__ == "__main__":
                 stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
@@ -166,6 +186,7 @@ if __name__ == "__main__":
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
@@ -208,7 +229,7 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        dataset_config={
+        train_dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
             "system_prompt": args.system_prompt,
@@ -238,4 +259,11 @@ if __name__ == "__main__":
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
+        eval_dataset_config={
+            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+            for k, v in json.loads(args.eval_dataset).items()
+        },
+        eval_interval=args.eval_interval,
+        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        eval_generation_config=eval_generation_config,
     )