diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 0cb35b94b..28d83fa40 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -33,7 +33,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
-        microbatch_size: int = 1,
+        minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
     ):
@@ -46,11 +46,11 @@ class BaseConsumer:
         self.num_update_per_episode = num_update_per_episode
         self.num_recv_per_update = num_recv_per_update
         self.batch_size = batch_size
-        self.microbatch_size = microbatch_size
+        self.minibatch_size = minibatch_size
         self.save_interval = save_interval
         self.save_dir = save_dir
-        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-        self.num_microbatches = batch_size // microbatch_size
+        assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
+        self.num_microbatches = batch_size // minibatch_size

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -67,7 +67,7 @@ class BaseConsumer:
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-            plugin_config["microbatch_size"] = self.microbatch_size
+            plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
@@ -105,22 +105,22 @@ class BaseConsumer:
                                )
                            )
                        )
-                    while len(self.buffer) >= self.dp_size * self.microbatch_size:
+                    while len(self.buffer) >= self.dp_size * self.minibatch_size:
                         batches = self.buffer[
-                            self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+                            self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                         ]
                         batch = pad_batch(
                             batches
                         )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                         batch = bind_batch(batches)
                         batch = post_recv(batch)
-                        loss, num_excessive_rollouts = self.step(i, pbar, **batch)
+                        loss, num_excessive_prompts = self.step(i, pbar, **batch)
                         self.buffer = (
                             self.buffer[
-                                (self.dp_rank + 1) * self.microbatch_size
-                                - num_excessive_rollouts : (self.dp_rank + 1) * self.microbatch_size
+                                (self.dp_rank + 1) * self.minibatch_size
+                                - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
                             ]
-                            + self.buffer[self.dp_size * self.microbatch_size :]
+                            + self.buffer[self.dp_size * self.minibatch_size :]
                         )
                         if loss is not None:
                             pbar.set_postfix({"loss": loss})
@@ -162,7 +162,8 @@ class SimpleConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
+        save_interval: int = 100,
         save_dir="./model",
     ):
         super().__init__(
@@ -177,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 490c93b86..831612f80 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,12 +1,9 @@
-import json
-import os
 import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

 import ray
 import torch
-import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -35,22 +32,23 @@ class GRPOConsumer(BaseConsumer):
         batch_size,
         model_config,
         plugin_config,
-        microbatch_size=1,
+        minibatch_size=1,
         num_generations=8,
         use_wandb=True,
         generate_config=None,
         grpo_config={},
         project_name=None,
+        save_interval: int = 100,
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != microbatch_size:
+            if batch_size != minibatch_size:
                 warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {microbatch_size}->{batch_size}",
+                    f"Applied token_level loss; forcing mini-batch size to match batch size: {minibatch_size} -> {batch_size}",
                     UserWarning,
                 )
-                microbatch_size = batch_size
+                minibatch_size = batch_size
         super().__init__(
             num_producers,
             num_episodes,
@@ -63,7 +61,7 @@ class GRPOConsumer(BaseConsumer):
             batch_size,
             model_config,
             plugin_config,
-            microbatch_size,
+            minibatch_size,
             save_dir=save_dir,
         )
         path = model_config.pop("path")
@@ -166,10 +164,10 @@ class GRPOConsumer(BaseConsumer):
                }, ...]
            Format:
-                [batch_size, num_of_generation, prompt_length + response_length] --- .............
+                [minibatch_size, num_of_generation, prompt_length + response_length] --- .............
        """

-        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
         data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
@@ -187,15 +185,15 @@ class GRPOConsumer(BaseConsumer):
            format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
            ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

-            # [batch_size, num_generations]
+            # [minibatch_size, num_generations]
            group_reward = reward.view(-1, self.num_generations)
            reward_mean = group_reward.mean(dim=1)
-            # [batch_size x num_generations]
+            # [minibatch_size x num_generations]
            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
+            # [minibatch_size x num_generations]
            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
            group_ans_acc = (
@@ -522,125 +520,3 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
-
-
-@ray.remote
-class GRPOEvalConsumer(BaseConsumer):
-    def __init__(
-        self,
-        num_producers,
-        num_episodes,
-        rank,
-        world_size,
-        master_addr,
-        master_port,
-        num_update_per_episode,
-        num_recv_per_update,
-        batch_size,
-        model_config,
-        plugin_config,
-        microbatch_size=1,
-        num_generations=4,
-        use_wandb=True,
-        log_dir="./results",
-    ):
-        super().__init__(
-            num_producers,
-            num_episodes,
-            rank,
-            world_size,
-            master_addr,
-            master_port,
-            num_update_per_episode,
-            num_recv_per_update,
-            batch_size,
-            model_config,
-            plugin_config,
-            microbatch_size,
-        )
-        path = model_config.pop("path")
-        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
-        self.policy_model.train()
-        self.accum_reward = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
-        self.accum_count = torch.zeros(1, device=self.device)
-
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.pad_token_id = self.tokenizer.pad_token_id
-        self.num_generations = num_generations
-
-        # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
-        self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
-        )
-
-        self.log_dir = log_dir
-        if not os.path.exists(self.log_dir):
-            os.makedirs(self.log_dir)
-        else:
-            os.system(f"rm -rf {self.log_dir}/*")
-
-    def setup(self):
-        super().setup()
-        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
-
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
-        rank = dist.get_rank()
-        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
-        kwargs["input_ids"].size(0)
-        reward_group = self.reward_model(
-            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
-        )
-        reward = [value[0].item() for value in reward_group]
-        format_acc = [value[1].item() for value in reward_group]
-        ans_acc = [value[2].item() for value in reward_group]
-        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
-
-        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
-        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
-            for i in range(len(response)):
-                f.write(
-                    json.dumps(
-                        {
-                            "response": response[i],
-                            "reward": reward[i],
-                            "format_acc": format_acc[i],
-                            "ans_acc": ans_acc[i],
-                            "response_length": response_length[i],
-                        },
-                        ensure_ascii=False,
-                    )
-                    + "\n"
-                )
-
-        self.accum_reward += sum(reward)
-        self.accum_format_acc += sum(format_acc)
-        self.accum_ans_acc += sum(ans_acc)
-        self.accum_response_length += sum(response_length)
-        self.accum_count += len(reward)
-
-        # print results
-        total_count = all_reduce_mean(self.accum_count, self.plugin)
-        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
-        mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count
-        mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count
-        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
-        if rank == 0:
-            print(
-                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}"
-            )
-        return None
-
-    def state_dict(self):
-        self.policy_model._force_wait_all_gather()
-        model = self.policy_model.unwrap()
-        state_dict = model.state_dict()
-        return state_dict
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 30bd9cb16..b193dc8bc 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
 import ray

 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}


 def get_jsonl_size_fast(path: str) -> int:
@@ -49,6 +49,8 @@ def launch_distributed(
     master_port: int = 29500,
     core_algo: str = "GRPO",
     project_name: Optional[str] = None,
+    save_interval: int = 100,
+    save_dir: str = "./model",
 ):

     if core_algo not in ALGO_MAP:
@@ -102,12 +104,13 @@ def launch_distributed(
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_minibatch_size,
+            minibatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
             grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
-            save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
+            save_interval=save_interval,
+            save_dir=save_dir,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index f10a97794..b0624f72a 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -103,7 +103,14 @@ class BaseProducer:
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
-                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                    [
+                        (
+                            self.model.generate_config["temperature"]
+                            if isinstance(self.model.generate_config, dict)
+                            else self.model.generate_config.temperature
+                        )
+                    ]
+                    * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
@@ -136,9 +143,14 @@
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
-                        "temperature"
-                    ] + ratio * 0.9
+                    if isinstance(self.model.generate_config, dict):
+                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9
+                    else:
+                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                            "temperature"
+                        ] + ratio * 0.9


 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 3cf7a1af3..fd9129130 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -5,8 +5,7 @@ from .reward_utils import extract_solution, validate_response_structure

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    soft_over_length_punishment = kwargs["soft_over_length_punishment"]
-    format_score = 0.0
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
     format_acc = torch.tensor(0.0)
@@ -33,7 +32,6 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     # Check format accuracy
     if format_valid:
         format_acc += 1
-        reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if (
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 64b8d7e68..9c0ec7922 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -1,4 +1,5 @@
 import argparse
+import os

 import ray
 import torch
@@ -8,15 +9,17 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+
+    # Distributed training parameters
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
-    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
         "-ibs",
         "--inference-batch-size",
         type=int,
-        default=64,
+        default=None,
         help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
@@ -37,7 +40,7 @@ if __name__ == "__main__":
         "-tMbs",
         "--train-minibatch-size",
         type=int,
-        default=1,
+        default=None,
         help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
@@ -47,23 +50,61 @@ if __name__ == "__main__":
         default=2,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
-    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     parser.add_argument(
         "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
     )
     parser.add_argument(
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
+    parser.add_argument(
+        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
+    )
+
+    # Sampling parameters
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
+    parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
+    parser.add_argument(
+        "-topk",
+        "--top-k",
+        type=int,
+        default=None,
+        help="Top k for sampling. Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-topp",
+        "--top-p",
+        type=float,
+        default=1.0,
+        help="Top p for sampling. Please check the generation arguments documentation for your backend.",
+    )
     parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
+    parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
+    parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
+
+    # GRPO parameters
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
+    parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+
+    # Logging/Checkpointing parameters
+    parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
+    parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
+
     args = parser.parse_args()

-    assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
+    if args.train_minibatch_size is None:
+        # Default settings: Using train batch size as mini batch size
+        args.train_minibatch_size = args.train_batch_size
+    if args.inference_batch_size is None:
+        # Default settings: Using train batch size as inference batch size, syncing the inference model every train step
+        args.inference_batch_size = args.train_batch_size
     assert (
         args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
-    assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"
+    assert (
+        args.train_minibatch_size <= args.train_batch_size
+    ), "Train mini batch size must be less than or equal to train batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -72,9 +113,15 @@ if __name__ == "__main__":
         # For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
         ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)

+    if args.top_k is None:
+        if args.backend == "transformers":
+            args.top_k = 50
+        elif args.backend == "vllm":
+            args.top_k = -1
+
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
-    generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
+    generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":
         inference_model_config.update(
@@ -85,7 +132,7 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_length=1024 + 512,
+                max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
@@ -98,54 +145,48 @@ if __name__ == "__main__":
                 gpu_memory_utilization=0.7,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
-                max_model_len=1024 * 4 + 510,
+                max_model_len=args.max_new_tokens + args.max_prompt_tokens,
                 tensor_parallel_size=1,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=1024 * 4,
+                max_tokens=args.max_new_tokens,  # max new tokens
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
             )
         )
     else:
-        inference_model_config.update(
-            dict(
-                mem_fraction_static=0.6,
-            )
-        )
-        generate_config.update(
-            dict(
-                max_new_tokens=256,
-                ignore_eos=True,
-            )
-        )
+        raise ValueError(f"Unsupported backend: {args.backend}")

-    # Default Settings
-    # grpo_config = {
-    #     "filter_range": [0.05, 9.0],
-    #     "lr": 1e-6,
-    #     "train_microbatch_size": train_microbatch_size,
-    # }
-
-    # DAPO variant settings
-    grpo_config = {
-        "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
-        "lr": 1e-6,
-        "train_microbatch_size": args.train_microbatch_size,
-        "dynamic_batching": True,
-        "clip_eps_low": 0.2,
-        "clip_eps_high": 0.28,
-        "skip_threshold": 20.0,
-        "beta": 0.0,  # no KL penalty
-        "loss_variation": "token_level",
-        "soft_over_length_punishment": True,
-        "max_length": 1024 * 4,
-        "cache_length": 512,
-        "filter_truncated_response": True,
-    }
+    if args.algo == "GRPO":
+        # Default Settings
+        grpo_config = {
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+        }
+    elif args.algo == "DAPO":
+        # DAPO variant settings
+        grpo_config = {
+            "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "dynamic_batching": True,
+            "clip_eps_low": 0.2,
+            "clip_eps_high": 0.28,
+            "skip_threshold": 20.0,
+            "beta": 0,  # no KL penalty for DAPO
+            "loss_variation": "token_level",
+            "soft_over_length_punishment": True,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "cache_length": min(1024, int(args.max_new_tokens / 4)),
+            "filter_truncated_response": True,
+        }
+    else:
+        raise ValueError(f"Unsupported algorithm: {args.algo}")

     launch_distributed(
         num_producers=args.num_inferencer,
@@ -157,7 +198,11 @@ if __name__ == "__main__":
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
+        dataset_config={
+            "path": args.dataset,
+            "max_length": args.max_prompt_tokens,
+            "system_prompt": args.system_prompt,
+        },
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
@@ -167,6 +212,7 @@ if __name__ == "__main__":
         plugin_config={
             "zero_stage": 2,
         },  # for zero
+        # tp/pp is not supported yet
         # plugin_config={
         #     "tp_size": 2,
         #     "microbatch_size": args.train_microbatch_size // 2,
@@ -175,7 +221,9 @@ if __name__ == "__main__":
         # },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29506,
+        master_port=args.master_port,
         core_algo=args.algo,
         project_name=args.project,
+        save_interval=args.save_interval,
+        save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
     )
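
Note on the new rl_example.py defaults (editorial addition, not part of the patch): the standalone Python sketch below mirrors the argument post-processing introduced above, showing how an omitted -tMbs or -ibs flag falls back to -tbs and which constraints are then asserted. The helper name resolve_batch_sizes is hypothetical; the script performs these steps inline after parser.parse_args().

def resolve_batch_sizes(
    train_batch_size,
    train_minibatch_size=None,
    inference_batch_size=None,
    train_microbatch_size=2,
    num_generations=8,
):
    # -tMbs defaults to -tbs: one optimizer step consumes one full batch of prompts.
    if train_minibatch_size is None:
        train_minibatch_size = train_batch_size
    # -ibs defaults to -tbs: inference weights are synced every training step.
    if inference_batch_size is None:
        inference_batch_size = train_batch_size
    # Same constraints the script asserts after filling in the defaults.
    assert train_minibatch_size * num_generations >= train_microbatch_size > 0
    assert train_minibatch_size <= train_batch_size
    return train_minibatch_size, inference_batch_size


# e.g. `python rl_example.py -tbs 32 -tmbs 2 -g 8` with -tMbs and -ibs omitted:
print(resolve_batch_sizes(train_batch_size=32))  # -> (32, 32)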