From 47a7dc7142586859c73e1f07bf06ee9829e09adf Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 30 Apr 2025 18:13:40 +0800
Subject: [PATCH 1/8] Support evaluation during training

---
 .gitignore                                    |   1 +
 .../coati/distributed/consumer.py             |  28 +++-
 .../coati/distributed/grpo_consumer.py        |   3 +
 .../coati/distributed/inference_backend.py    |  10 +-
 .../ColossalChat/coati/distributed/launch.py  |  16 +-
 .../coati/distributed/producer.py             | 149 +++++++++++++++---
 .../coati/distributed/reward/reward_fn.py     |  54 +++----
 .../ColossalChat/coati/distributed/utils.py   |  13 ++
 applications/ColossalChat/rl_example.py       |  25 ++-
 9 files changed, 234 insertions(+), 65 deletions(-)

diff --git a/.gitignore b/.gitignore
index 533450a7c..d48035a5b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 9503d65eb..620ab6ff6 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -36,6 +36,7 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ class BaseConsumer:
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.eval_interval = eval_interval

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -92,8 +94,10 @@ class BaseConsumer:
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
-        self.buffer = []
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")

+        self.buffer = []
         self.recv_cnt = 0

     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -110,6 +114,24 @@ class BaseConsumer:
         with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
             for step in pbar:
                 i = 0
+                if self.eval_interval > 0 and step % self.eval_interval == 0:
+                    eval_statistics = None
+                    for r in range(self.num_producers):
+                        print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
+                        local_eval_result = ray_broadcast_tensor_dict(
+                            None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
+                        )
+                        if eval_statistics is None:
+                            eval_statistics = local_eval_result
+                        else:
+                            eval_statistics = {
+                                k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
+                            }
+                    eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                    if dist.get_rank() == 0:
+                        if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
+                            self.wandb_run.log(eval_statistics, step=self.global_step)
+                    print(f"Eval statistics: {eval_statistics}")
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -192,6 +214,7 @@ class SimpleConsumer(BaseConsumer):
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         super().__init__(
             num_producers,
@@ -206,6 +229,9 @@ class SimpleConsumer(BaseConsumer):
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
+            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 877ff98ec..0336125d1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -40,6 +40,7 @@ class GRPOConsumer(BaseConsumer):
         project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer):
             self.policy_model._force_wait_all_gather()
             model = self.policy_model.unwrap()
             state_dict = model.state_dict()
+            state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
             return state_dict

diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 8ff4b15bf..64dfe5cfd 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = SamplingParams(**generate_config)
+        self.generate_config = generate_config
+        self.sample_params = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
+        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
         out_tokens = []
         out_len = []
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a346d1d4f..d1b6158e5 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,6 +50,9 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers
+    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

@@ -74,7 +77,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +86,10 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@ def launch_distributed(
         project_name=project_name,
         save_interval=save_interval,
         save_dir=save_dir,
+        eval_interval=eval_interval,
     )
     procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index ffe3a4428..483ad74b3 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -1,9 +1,13 @@
+import copy
+import os
 from typing import Any, Dict, Optional

 import ray
 import ray.util.collective as cc
 import torch
+import tqdm
 from coati.dataset.loader import RawConversationDataset
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

@@ -11,7 +15,12 @@ from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send
+from .utils import pre_send, safe_write_jsonl
+
+try:
+    from vllm import SamplingParams
+except ImportError:
+    SamplingParams = None


 class BaseProducer:
@@ -22,7 +31,7 @@ class BaseProducer:
         num_consumer_procs: int,
         num_episodes: int,
         batch_size: int,
-        dataset_config: Dict[str, Any],
+        train_dataset_config: Dict[str, Any],
         dataloaders_config: Dict[str, Any],
         model_config: Dict[str, Any],
         generate_config: Dict[str, Any],
@@ -30,6 +39,10 @@ class BaseProducer:
         microbatch_size: int = 1,
         backend: str = "transformers",
         consumer_plugin_config: Dict[str, Any] = None,
+        eval_dataset_config=None,
+        eval_interval=-1,  # disable evaluation
+        evaluation_function_type="think_answer_tags",
+        eval_save_dir: str = "./eval",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -40,10 +53,17 @@ class BaseProducer:
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size

-        self.dataset_config = dataset_config
+        self.train_dataset_config = train_dataset_config
         self.model_config = model_config
         self.generate_config = generate_config
         self.tokenizer_config = tokenizer_config
+        self.consumer_plugin_config = consumer_plugin_config
+        self.eval_interval = eval_interval
+        self.eval_save_dir = eval_save_dir
+        self.consumer_global_step = 0
+
+        if os.path.exists(self.eval_save_dir):
+            raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
         if tokenizer_config is None:
@@ -55,13 +75,13 @@ class BaseProducer:
         self.tokenizer.padding_side = "left"

         # init dataloader
-        dataset_path = dataset_config.pop("path")
-        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
-        self.dataloader = DataLoader(
-            self.dataset,
+        train_dataset_path = train_dataset_config.pop("path")
+        self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
+        self.train_dataloader = DataLoader(
+            self.train_dataset,
             batch_size=microbatch_size,
             sampler=DistributedSampler(
-                self.dataset,
+                self.train_dataset,
                 num_replicas=num_producers,
                 rank=producer_idx,
                 shuffle=True,
@@ -71,6 +91,36 @@ class BaseProducer:
             num_workers=4,
             drop_last=True,
         )
+
+        self.eval_dataset_config = eval_dataset_config
+        if self.eval_dataset_config is not None:
+            self.eval_dataloaders = {}
+            for eval_task_name in self.eval_dataset_config:
+                eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
+                eval_dataset = RawConversationDataset(
+                    self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
+                )
+                print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
+                self.eval_dataloaders[eval_task_name] = DataLoader(
+                    eval_dataset,
+                    batch_size=microbatch_size,
+                    sampler=DistributedSampler(
+                        eval_dataset,
+                        num_replicas=num_producers,
+                        rank=producer_idx,
+                        shuffle=False,
+                        drop_last=False,
+                        seed=42,
+                    ),
+                )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+        else:
+            raise ValueError("eval_dataset_config is not defined")

         self.device = get_current_device()

         # init backend
@@ -88,6 +138,7 @@ class BaseProducer:
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
group_name="sync_model") + cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}") def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -96,29 +147,64 @@ class BaseProducer: raise NotImplementedError def loop(self) -> None: - num_update_per_episode = len(self.dataloader) // self.num_microbatches + num_update_per_episode = len(self.train_dataloader) // self.num_microbatches num_valid_microbatches = num_update_per_episode * self.num_microbatches print( - f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}" + f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}" ) for episode in range(self.num_episodes): - self.dataloader.sampler.set_epoch(episode) - for i, batch in enumerate(self.dataloader): + self.train_dataloader.sampler.set_epoch(episode) + for i, batch in enumerate(self.train_dataloader): if i >= num_valid_microbatches: break + if self.eval_interval > 0 and self.eval_dataset_config is not None: + if i % self.eval_interval == 0: + eval_statistics = {} + for eval_task_name in self.eval_dataloaders: + print( + f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}" + ) + eval_results = [] + eval_statistics[eval_task_name] = torch.zeros(2, device=self.device) + for eval_batch in tqdm.tqdm( + self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0 + ): + eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params) + eval_results = eval_results + [ + self.evaluation_function( + eval_outputs["input_ids"][m][n], + eval_outputs["gt_answer"][m][n], + eval_outputs["response_idx"][m][n], + tokenizer=self.tokenizer, + eval_mode=True, + ) + for m in range(eval_outputs["input_ids"].size(0)) + for n in range(eval_outputs["input_ids"].size(1)) + ] + eval_statistics[eval_task_name][0] += len( + [res for res in eval_results if res["ans_valid"] == 1] + ) + eval_statistics[eval_task_name][1] += len(eval_results) + # save eval results + result_file_name = os.path.join( + self.eval_save_dir, + f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl", + ) + # delete the file if it exists + safe_write_jsonl(result_file_name, eval_results) + print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}") + ray_broadcast_tensor_dict( + eval_statistics, + src=0, + device=self.device, + group_name=f"sync_eval_statistics_{self.producer_idx}", + ) outputs = self.rollout(**batch) print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( - [ - ( - self.model.generate_config["temperature"] - if isinstance(self.model.generate_config.temperature, dict) - else self.model.generate_config.temperature - ) - ] - * outputs["input_ids"].size(0) + [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) outputs = pre_send(outputs) ray_broadcast_tensor_dict( @@ -150,6 +236,8 @@ class BaseProducer: state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) + if "consumer_global_step" in state_dict: + self.consumer_global_step = state_dict.pop("consumer_global_step").item() self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() @@ -159,7 
+247,7 @@ class BaseProducer: self.model.llm.wake_up() # linear annealing for 1 episode, temperature from initial to 0.9 if episode <= 0: - ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) + ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) if isinstance(self.model.generate_config.temperature, dict): self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ "temperature" @@ -179,7 +267,7 @@ class SimpleProducer(BaseProducer): num_consumer_procs, num_episodes, batch_size, - dataset_config, + train_dataset_config, dataloaders_config, model_config, generate_config, @@ -188,6 +276,10 @@ class SimpleProducer(BaseProducer): backend="transformers", num_generations: int = 8, consumer_plugin_config=None, + eval_dataset_config=None, + eval_interval=-1, # disable evaluation + evaluation_function_type="think_answer_tags", + eval_save_dir: str = "./eval", ): super().__init__( producer_idx, @@ -195,7 +287,7 @@ class SimpleProducer(BaseProducer): num_consumer_procs, num_episodes, batch_size, - dataset_config, + train_dataset_config, dataloaders_config, model_config, generate_config, @@ -203,14 +295,21 @@ class SimpleProducer(BaseProducer): microbatch_size, backend, consumer_plugin_config, + eval_dataset_config=eval_dataset_config, + eval_interval=eval_interval, + evaluation_function_type=evaluation_function_type, + eval_save_dir=eval_save_dir, ) self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) + self.eval_generation_config = copy.deepcopy(self.model.generate_config) + self.eval_generation_config["n"] = 1 # use 1 generation for evaluation + self.eval_sample_params = SamplingParams(**self.eval_generation_config) @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): rollouts = self.model.generate(input_ids, attention_mask, **kwargs) - if self.producer_idx == 1: - print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) + # if self.producer_idx == 1: + # print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) return rollouts diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index b68c1a92f..14d340dc4 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_res def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] + eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) acc_score = 10.0 reward = torch.tensor(0.0) @@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): reward = reward + length_reward - return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) - - -def gsm8k_reward_fn(input_ids, **kwargs): - gt_answer = kwargs["gt_answer"] - tokenizer = kwargs["tokenizer"] - s, e = kwargs["response_start"], kwargs["response_end"] - reward = torch.tensor(0.0).to(input_ids.device) - if gt_answer is None: - return reward - decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) - final_answer, processed_str = extract_solution(decoded_final_answer) - is_valid = True - try: - int(final_answer.strip()) - except Exception: - is_valid = False - - 
format_valid = validate_response_structure(processed_str, kwargs["tags"]) - if not is_valid or not format_valid: - return reward + if not eval_mode: + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) else: - reward += 1.0 - if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): - reward = reward + 9.0 - return reward + prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True) + return { + "prompt": prompt, + "prediction": decoded_final_answer, + "gold": gt_answer, + "parsed": final_answer, + "format_valid": format_acc.item(), + "ans_valid": ans_acc.item(), + } def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] + eval_mode = kwargs.get("eval_mode", False) soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) format_score = 0.0 acc_score = 10.0 @@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score if gt_answer is None: - return reward + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) @@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): reward += acc_score reward = reward + length_reward - - return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) + if not eval_mode: + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) + else: + prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True) + return { + "prompt": prompt, + "prediction": decoded_final_answer, + "gold": gt_answer, + "parsed": final_answer, + "format_valid": format_acc.item(), + "ans_valid": ans_acc.item(), + } diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 5f7879669..fb2a0dfd3 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -1,6 +1,9 @@ +import json +import os from typing import Any, Dict, List import torch +from filelock import FileLock from colossalai.shardformer.layer.loss import dist_log_prob @@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch. 
""" tensor = tensor * mask return tensor.sum(dim=dim) + + +def safe_write_jsonl(file_path, data): + with FileLock(file_path + ".lock"): + # Ensure file exists + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "a", encoding="utf8") as f: + for entry in data: + json_line = json.dumps(entry, ensure_ascii=False) + f.write(json_line + "\n") diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 788e60c2e..b2fafd42e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -1,4 +1,5 @@ import argparse +import json import os import ray @@ -8,7 +9,16 @@ from coati.distributed.launch import launch_distributed if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") - parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") + parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl") + parser.add_argument( + "-ed", + "--eval-dataset", + type=str, + default='{"eval task name":"data_eval.jsonl"}', + help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ + For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \ + The key is the task name, and the value is the path to the jsonl file", + ) parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") @@ -94,11 +104,14 @@ if __name__ == "__main__": choices=["think_answer_tags", "boxed"], help="Reward type for GRPO.", ) + parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.") # Logging/Checkpointing parameters parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.") - + parser.add_argument( + "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results." 
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -208,7 +221,7 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        dataset_config={
+        train_dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
             "system_prompt": args.system_prompt,
@@ -238,4 +251,10 @@ if __name__ == "__main__":
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
+        eval_dataset_config={
+            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+            for k, v in json.loads(args.eval_dataset).items()
+        },
+        eval_interval=args.eval_interval,
+        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
     )

From 50070c1e84a36f11d21efddd3c6828cd32b56f16 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 14 May 2025 18:10:57 +0800
Subject: [PATCH 2/8] move logging to producer

---
 .../coati/distributed/consumer.py             | 25 ------
 .../coati/distributed/grpo_consumer.py        | 30 ++++---
 .../coati/distributed/inference_backend.py    |  2 +-
 .../ColossalChat/coati/distributed/launch.py  | 14 ++-
 .../coati/distributed/producer.py             | 86 +++++++++++++------
 .../ColossalChat/coati/distributed/utils.py   |  2 +-
 applications/ColossalChat/rl_example.py       |  3 +
 7 files changed, 92 insertions(+), 70 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 620ab6ff6..800a3c799 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -36,7 +36,6 @@ class BaseConsumer:
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
-        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -52,7 +51,6 @@ class BaseConsumer:
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
-        self.eval_interval = eval_interval

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -94,9 +92,6 @@ class BaseConsumer:
         if self.rank == 0:
             cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
-
         self.buffer = []
         self.recv_cnt = 0

@@ -114,24 +109,6 @@ class BaseConsumer:
         with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
             for step in pbar:
                 i = 0
-                if self.eval_interval > 0 and step % self.eval_interval == 0:
-                    eval_statistics = None
-                    for r in range(self.num_producers):
-                        print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
-                        local_eval_result = ray_broadcast_tensor_dict(
-                            None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
-                        )
-                        if eval_statistics is None:
-                            eval_statistics = local_eval_result
-                        else:
-                            eval_statistics = {
-                                k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
-                            }
-                    eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
-                    if dist.get_rank() == 0:
-                        if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
-                            self.wandb_run.log(eval_statistics, step=self.global_step)
-                    print(f"Eval statistics: {eval_statistics}")
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -191,7 +168,6 @@ class SimpleConsumer(BaseConsumer):
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
     ):
         super().__init__(
             num_producers,
@@ -207,7 +183,6 @@ class SimpleConsumer(BaseConsumer):
             minibatch_size,
             save_interval,
             save_dir,
-            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 0336125d1..f301345b6 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -34,13 +34,13 @@ class GRPOConsumer(BaseConsumer):
         plugin_config,
         minibatch_size=1,
         num_generations=8,
-        use_wandb=True,
         generate_config=None,
         grpo_config={},
-        project_name=None,
         save_interval: int = 100,
         save_dir="./model",
-        eval_interval: int = -1,
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +73,6 @@ class GRPOConsumer(BaseConsumer):
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
-            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -93,6 +92,9 @@ class GRPOConsumer(BaseConsumer):
         self.project_name = project_name
         self.effective_sample_count = 0
         self.total_sample_count = 0
+        self.project_name = project_name
+        self.run_name = run_name
+        self.wandb_group_name = wandb_group_name

         self.policy_loss_fn = PolicyLoss(
             clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
@@ -143,7 +145,6 @@ class GRPOConsumer(BaseConsumer):
             **reward_model_kwargs,
         )
         self.global_step = 0
-        self.use_wandb = use_wandb

         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -154,13 +155,16 @@ class GRPOConsumer(BaseConsumer):

     def setup(self):
         super().setup()
-        if self.use_wandb and (
-            (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+        if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
-            # Initialize wandb.
-            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(
+                project=self.project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=self.run_name,
+                group=self.wandb_group_name,
+            )

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -512,8 +516,8 @@ class GRPOConsumer(BaseConsumer):
                 }
                 if self.policy_loss_fn.beta > 0:
                     metrics["train/kl"] = self.accum_kl.item() / self.accum_count
-
-                self.wandb_run.log(metrics)
+                if self.wandb_run is not None:
+                    self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_ans_acc.zero_()

diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 64dfe5cfd..ce2a2f28c 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -238,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
                 log_probs.append(p)

         # pad them
-        max_len = self.generate_config.max_tokens
+        max_len = self.sample_params.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)

         for i, new_token_ids in enumerate(out_tokens):

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index d1b6158e5..cb5b6e4e2 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,4 +1,5 @@
 import copy
+import uuid
 from typing import Any, Dict, Optional

 import ray
@@ -53,6 +54,7 @@ def launch_distributed(
     eval_dataset_config: Optional[Dict[str, Any]] = None,
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
+    eval_generation_config: Optional[Dict[str, Any]] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -69,6 +71,9 @@ def launch_distributed(
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

+    run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+    wandb_group_name = str(uuid.uuid4())
+
     procs = []
     for i in range(num_producers):
         producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
@@ -90,6 +95,10 @@ def launch_distributed(
             eval_interval=eval_interval,
             evaluation_function_type=grpo_config["reward_fn_type"],
             eval_save_dir=eval_save_dir,
+            eval_generation_config=eval_generation_config,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -115,10 +124,11 @@ def launch_distributed(
         generate_config=generate_config_consumer,
         grpo_config=grpo_config,
         num_generations=num_generations,
-        project_name=project_name,
         save_interval=save_interval,
         save_dir=save_dir,
-        eval_interval=eval_interval,
+        project_name=project_name,
+        run_name=run_name,
+        wandb_group_name=wandb_group_name,
     )
     procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 483ad74b3..fa2d1cc8d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -6,8 +6,11 @@ import ray
 import ray.util.collective as cc
 import torch
 import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

@@ -15,7 +18,7 @@ from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
+from .utils import pre_send, safe_append_to_jsonl_file

 try:
     from vllm import SamplingParams
@@ -43,6 +46,9 @@ class BaseProducer:
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -61,6 +67,14 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )

         if os.path.exists(self.eval_save_dir):
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
@@ -132,13 +146,18 @@ class BaseProducer:
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -160,13 +179,14 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if i % self.eval_interval == 0:
-                        eval_statistics = {}
+                        to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                )
                             eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                             for eval_batch in tqdm.tqdm(
                                 self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                             ):
@@ -182,24 +202,27 @@ class BaseProducer:
                                 for m in range(eval_outputs["input_ids"].size(0))
                                 for n in range(eval_outputs["input_ids"].size(1))
                             ]
-                            eval_statistics[eval_task_name][0] += len(
-                                [res for res in eval_results if res["ans_valid"] == 1]
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
                             )
-                            eval_statistics[eval_task_name][1] += len(eval_results)
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
                             # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
                             )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_eval_statistics_{self.producer_idx}",
-                        )
+
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -248,12 +271,11 @@ class BaseProducer:
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                             "temperature"
                         ] + ratio * 0.9

@@ -280,6 +302,10 @@ class SimpleProducer(BaseProducer):
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -299,10 +325,14 @@ class SimpleProducer(BaseProducer):
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
             eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
         self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
         self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()

diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index fb2a0dfd3..a40ebbcfb 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -135,7 +135,7 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
     return tensor.sum(dim=dim)


-def safe_write_jsonl(file_path, data):
+def safe_append_to_jsonl_file(file_path, data):
     with FileLock(file_path + ".lock"):
         # Ensure the parent directory exists
         os.makedirs(os.path.dirname(file_path), exist_ok=True)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b2fafd42e..9fecdb52d 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -161,6 +161,7 @@ if __name__ == "__main__":
                 stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
@@ -179,6 +180,7 @@ if __name__ == "__main__":
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
+        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")

@@ -257,4 +259,5 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        eval_generation_config=eval_generation_config,
     )

From 4ec73298b22315ba27ef68df4c4836ce9be0e5ee Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 15 May 2025 14:15:40 +0800
Subject: [PATCH 3/8] use consumer global step

---
 .../ColossalChat/coati/distributed/grpo_consumer.py     | 2 --
 applications/ColossalChat/coati/distributed/launch.py   | 1 -
 applications/ColossalChat/coati/distributed/producer.py | 9 +++++++--
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 0b0c8a9ad..5ad73f45c 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -266,7 +266,6 @@ class GRPOConsumer(BaseConsumer):
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
             self.total_sample_count += total_samples.item()
-
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
@@ -522,7 +521,6 @@ class GRPOConsumer(BaseConsumer):
                 # All gather excessive prompts index across DP ranks.
                 excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
                 excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-
                 return loss_scalar, excessive_prompts_idx
             else:
                 return None, excessive_prompts_idx

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index cb5b6e4e2..327db1b55 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -56,7 +56,6 @@ def launch_distributed(
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
 ):
-
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
     else:

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index fa2d1cc8d..f7c17bf56 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -58,6 +58,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
+        self.lastest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -178,12 +179,15 @@ class BaseProducer:
                 if i >= num_valid_microbatches:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
-                    if i % self.eval_interval == 0:
+                    if (
+                        self.consumer_global_step % self.eval_interval == 0
+                        and self.consumer_global_step > self.lastest_eval_step
+                    ):
                         to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
                                 )
                             eval_results = []
                             eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
@@ -223,6 +227,7 @@ class BaseProducer:

                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
+                        self.lastest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")

From 957e3a521aa41d0580c6ed32d77b4a9cf8581210 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 15 May 2025 16:52:31 +0800
Subject: [PATCH 4/8] disable wandb tb syncing

---
 applications/ColossalChat/coati/distributed/grpo_consumer.py | 2 +-
 applications/ColossalChat/coati/distributed/producer.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5ad73f45c..7ea2dbf18 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -153,7 +153,7 @@ class GRPOConsumer(BaseConsumer):
         ):
             self.wandb_run = wandb.init(
                 project=self.project_name,
-                sync_tensorboard=True,
+                sync_tensorboard=False,
                 dir="./wandb",
                 name=self.run_name,
                 group=self.wandb_group_name,

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index f7c17bf56..8b9452160 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -71,7 +71,7 @@ class BaseProducer:
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
-                sync_tensorboard=True,
+                sync_tensorboard=False,
                 dir="./wandb",
                 name=run_name + "_eval",
                 group=wandb_group_name,

From 1644adf6843b735af84e76c3346bea2e7ceee85d Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Thu, 15 May 2025 18:30:27 +0800
Subject: [PATCH 5/8] handle empty index

---
 .../coati/distributed/consumer.py             | 48 +++++++++----------
 .../coati/distributed/grpo_consumer.py        | 25 ++++++----
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index ecfd65458..816cab50a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -113,7 +113,6 @@ class BaseConsumer:
         ) as pbar:
             for step in pbar:
                 i = 0
-                allow_sync_model = False
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
@@ -139,7 +138,6 @@ class BaseConsumer:
                         else:
                             self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                         if loss is not None:
-                            allow_sync_model = True
                             pbar.set_postfix({"loss": loss})
                         i += 1
                 if self.lr_scheduler is not None:
@@ -153,31 +151,29 @@ class BaseConsumer:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    if allow_sync_model:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    torch.cuda.empty_cache()
+                    state_dict = self.state_dict()
+                    if self.pp_size > 1:
+                        if self.tp_rank == 0 and self.dp_rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict,
+                                src=self.num_producers,
+                                device=self.device,
+                                group_name=f"sync_model_{self.pp_rank}",
                             )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
-                                )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
-                        allow_sync_model = False
+                    else:
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
+                    del state_dict
+                    torch.cuda.empty_cache()

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 7ea2dbf18..eae4ff54e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -245,17 +245,22 @@ class GRPOConsumer(BaseConsumer):

                 # TODO: customize excessive prompts calculation.
                 if excessive_prompts_per_rank != 0:
                     # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
-                    if excessive_prompts_per_rank <= len(true_indices):
-                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                    else:
-                        excessive_prompts_idx = true_indices
-                    effective_prompts_mask[excessive_prompts_idx] = False
+                    true_indices = torch.nonzero(effective_prompts_mask)
+                    # Make sure the indices are not empty.
+                    if true_indices.numel() > 0:
+                        true_indices = true_indices.squeeze()
+                        if excessive_prompts_per_rank <= len(true_indices):
+                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                        else:
+                            excessive_prompts_idx = true_indices
+                        effective_prompts_mask[excessive_prompts_idx] = False

-                    for mask_idx in range(len(effective_prompts_mask)):
-                        if effective_prompts_mask[mask_idx] == False:
-                            # Update loss mask.
-                            loss_mask[mask_idx] = False
+                        for mask_idx in range(len(effective_prompts_mask)):
+                            if effective_prompts_mask[mask_idx] == False:
+                                # Update loss mask.
+                                loss_mask[mask_idx] = False
+                    else:
+                        excessive_prompts_idx = torch.empty([0])
             else:
                 # If dynamic batching is disabled, we need to use all samples for training.
                 need_update = (step_idx + 1) % self.num_microbatches == 0

From 6abffb9100fe6abe002f484ee2b1eedb17ed204b Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 16 May 2025 09:42:35 +0800
Subject: [PATCH 6/8] fix evaluation

---
 applications/ColossalChat/coati/distributed/producer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 8b9452160..8b8be8e16 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -180,7 +180,7 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step % self.eval_interval == 0
+                        self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
                         and self.consumer_global_step > self.lastest_eval_step
                     ):
@@ -256,6 +256,8 @@ class BaseProducer:
                         state_dict = ray_broadcast_tensor_dict(
                             None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                         )
+                        if "consumer_global_step" in state_dict:
+                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                         self.load_state_dict(state_dict)
                     else:
                         print(

From 203dfb15365beb2c63a364665e342ac44da25111 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 16 May 2025 14:15:35 +0800
Subject: [PATCH 7/8] address conversation

---
 applications/ColossalChat/coati/distributed/launch.py   | 2 +-
 applications/ColossalChat/coati/distributed/producer.py | 4 ++--
 applications/ColossalChat/rl_example.py                 | 8 +++++++-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 327db1b55..e6367d2d1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -66,7 +66,7 @@ def launch_distributed(

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 8b8be8e16..82f57094d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -187,7 +187,7 @@ class BaseProducer:
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
-                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {self.consumer_global_step} on task {eval_task_name}"
+                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                 )
@@ -220,7 +220,7 @@ class BaseProducer:
                             safe_append_to_jsonl_file(
                                 os.path.join(
                                     self.eval_save_dir,
-                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                 ),
                                 eval_results,
                             )

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 2ed9ef62c..8d1f25e74 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -104,7 +104,13 @@ if __name__ == "__main__":
         choices=["think_answer_tags", "boxed"],
         help="Reward type for GRPO.",
     )
-    parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-ei",
+        "--eval-interval",
+        type=int,
+        default=100,
+        help="Interval for evaluation. Evaluate every ei training steps.",
+    )

     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")

From 021914c565ed5310671b974cc28b4525bf3a5a86 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 16 May 2025 15:56:03 +0800
Subject: [PATCH 8/8] support logging rollouts to wandb

---
 .../coati/distributed/producer.py             | 42 +++++++++++++++----
 1 file changed, 33 insertions(+), 9 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 82f57094d..0d91f43f1 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -49,6 +49,7 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        wandb_log_rollout_interval: int = 20,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -58,7 +59,7 @@ class BaseProducer:
         self.microbatch_size = microbatch_size
         assert batch_size % microbatch_size == 0
         self.num_microbatches = batch_size // microbatch_size
-        self.lastest_eval_step = -1
+        self.latest_eval_step = -1

         self.train_dataset_config = train_dataset_config
         self.model_config = model_config
@@ -68,6 +69,10 @@ class BaseProducer:
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        self.eval_mode = False
+        self.wandb_rollout_data = []
+        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.latest_rollout_log_step = -1
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -77,7 +82,7 @@ class BaseProducer:
                 group=wandb_group_name,
             )

-        if os.path.exists(self.eval_save_dir):
+        if os.path.exists(self.eval_save_dir) and self.eval_interval > 0:
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")

         # init tokenizer
@@ -180,10 +185,11 @@ class BaseProducer:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if (
-                        self.consumer_global_step - self.lastest_eval_step >= self.eval_interval
-                        and self.consumer_global_step > self.lastest_eval_step
-                    ):
+                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
+                        and self.consumer_global_step > self.latest_eval_step
+                    ) or self.latest_eval_step == -1:
                         to_log_msg = {}
+                        self.eval_mode = True
                         for eval_task_name in self.eval_dataloaders:
                             if self.producer_idx == 0:
                                 print(
@@ -227,7 +233,8 @@ class BaseProducer:

                         if self.producer_idx == 0:
                             self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
-                        self.lastest_eval_step = self.consumer_global_step
+                        self.eval_mode = False
+                        self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -345,9 +352,26 @@ class SimpleProducer(BaseProducer):
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        # if self.producer_idx == 1:
-        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
-
+        if self.producer_idx == 0 and not self.eval_mode:
+            wandb_rollout_data = self.wandb_rollout_data + [
+                [
+                    str(self.consumer_global_step),
+                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
+                ]
+            ]
+            if (
+                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                or self.latest_rollout_log_step == -1
+            ):
+                self.wandb_rollout_data = wandb_rollout_data
+                self.latest_rollout_log_step = self.consumer_global_step
+                self.wandb_run.log(
+                    {
+                        "rollout/rollout_examples": wandb.Table(
+                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
+                        )
+                    }
+                )
         return rollouts

     def load_state_dict(self, state_dict):
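
The core of the evaluation protocol after PATCH 2 is simple arithmetic: each producer scores its shard of the eval set into a local [num_correct, num_total] tensor, all-reduces it across the "producer_group", and producer 0 logs correct/total as the task accuracy. The following standalone sketch illustrates that aggregation; the ray.util.collective.allreduce call is replaced here by an explicit sum over simulated producers, and the helper name local_eval_statistics is illustrative, not part of the patch.

    # Minimal sketch of the producer-side eval aggregation introduced in PATCH 2.
    # The allreduce(eval_statistics_tensor, op=ReduceOp.SUM, ...) collective is
    # simulated with an explicit sum; names here are illustrative only.
    import torch

    def local_eval_statistics(eval_results: list) -> torch.Tensor:
        # Mirrors eval_statistics_tensor: [num_correct, num_total] as float32.
        stats = torch.zeros((2,), dtype=torch.float32)
        stats[0] += len([res for res in eval_results if res["ans_valid"] == 1])
        stats[1] += len(eval_results)
        return stats

    # Three simulated producers, each holding a disjoint shard of the eval set
    # (DistributedSampler with shuffle=False, drop_last=False splits it this way).
    shards = [
        [{"ans_valid": 1}, {"ans_valid": 0}],
        [{"ans_valid": 1}, {"ans_valid": 1}],
        [{"ans_valid": 0}],
    ]
    local_stats = [local_eval_statistics(shard) for shard in shards]

    # After the real allreduce, every producer holds the same global tensor;
    # summing the per-producer tensors reproduces that result.
    global_stats = torch.stack(local_stats).sum(dim=0)
    accuracy = global_stats[0].item() / global_stats[1].item()
    print(f"eval/accuracy: {accuracy:.3f}")  # 3 correct / 5 total = 0.600

Because only the two scalar counts cross the network, this stays cheap regardless of eval-set size, and it is why PATCH 2 could drop the earlier consumer-side gather over per-producer "sync_eval_statistics_{i}" groups.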