import warnings
from contextlib import nullcontext
from typing import Any, Optional, Tuple

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam


@ray.remote
class GRPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        minibatch_size=1,
        num_generations=8,
        use_wandb=True,
        generate_config=None,
        grpo_config={},
        project_name=None,
        save_interval: int = 100,
        save_dir="./model",
    ):
        print(f"Using GRPO config: {grpo_config}")
        if grpo_config.get("loss_variation", "sample_level") == "token_level":
            if batch_size != minibatch_size:
                warnings.warn(
                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
                    UserWarning,
                )
                minibatch_size = batch_size
        if (
            plugin_config.get("pp_size", 1) > 1
            and "num_microbatches" not in plugin_config
            and "microbatch_size" not in plugin_config
        ):
            plugin_config["microbatch_size"] = max(
                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
            )
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            minibatch_size,
            save_interval=save_interval,
            save_dir=save_dir,
        )
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.orig_state_dict_key = [k for k in self.policy_model.state_dict()]
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_reward = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_format_acc = torch.zeros(1, device=self.device)
        self.accum_ans_acc = torch.zeros(1, device=self.device)
        self.accum_advantages = torch.zeros(1, device=self.device)
        self.accum_response_length = torch.zeros(1, device=self.device)
        self.accum_count = 0
        self.generate_config = generate_config
        self.grpo_config = grpo_config
        self.project_name = project_name
        self.effective_sample_count = 0
        self.total_sample_count = 0

        self.policy_loss_fn = PolicyLoss(
            clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
            beta=grpo_config.get("beta", 0.01),
            loss_variation=grpo_config.get("loss_variation", "sample_level"),
        )

        # Reference model is initialized from the policy model.
        if self.policy_loss_fn.beta > 0:
            self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
            self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations
        self.filter_range = grpo_config.get("filter_range", None)
        if self.filter_range is not None:
            assert len(self.filter_range) == 2, "Filter range should have 2 values."
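            # Note: filter_range is treated as an open interval (low, high) over the mean accuracy of a
            # generation group; prompts whose group accuracy falls outside it are masked out of the loss in step().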
        self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
        if self.filter_truncated_response:
            self.max_length = 0
            if "max_tokens" in self.generate_config:
                self.max_length = self.generate_config["max_tokens"]
            elif "max_new_tokens" in self.generate_config:
                self.max_length = self.generate_config["max_new_tokens"]
            else:
                raise ValueError(
                    "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
                )

        # Initialize verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        reward_model_kwargs = {
            k: v
            for k, v in grpo_config.items()
            if k in ["soft_over_length_punishment", "max_length", "cache_length"]
        }
        self.reward_model = VerifiableReward(
            reward_fns=[
                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
            ],
            tokenizer=self.tokenizer,
            tags=response_format_tags,
            **reward_model_kwargs,
        )
        self.global_step = 0
        self.use_wandb = use_wandb

        self.lr_scheduler = CosineAnnealingWarmupLR(
            optimizer=self.optimizer,
            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
            warmup_steps=0,
            eta_min=0.1 * grpo_config.get("lr", 1e-6),
        )

    def get_model_state_dict_keys(self):
        return self.orig_state_dict_key

    def setup(self):
        super().setup()
        if self.use_wandb and (
            (not self.plugin.pp_size > 1 and self.rank == 0)
            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
        ):
            # Initialize wandb.
            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)

        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
        )
        if self.policy_loss_fn.beta > 0:
            self.reference_model, *_ = self.booster.boost(self.reference_model)
        self.plugin.logger.set_level("ERROR")

    def step(self, step_idx: int, pbar: Any, **kwargs) -> Tuple[Optional[float], int]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
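        Returns:
            Tuple[Optional[float], int]: the accumulated loss once an optimizer step has been performed
            (None otherwise), and the number of excessive prompts deferred to the next rollout when
            dynamic batching is enabled.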
""" # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length] data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} action_mask = data["action_mask"] num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) reward_group = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"], ) reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) # [minibatch_size, num_generations] group_reward = reward.view(-1, self.num_generations) reward_mean = group_reward.mean(dim=1) # [minibatch_size x num_generations] reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) # [minibatch_size x num_generations] advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), group_ans_acc = ( ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0) ) loss_mask = ( torch.ones(action_mask.size(0), device=action_mask.device).bool() if self.filter_range is None else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1]) ) # filter out overlength samples if self.filter_truncated_response and action_mask.size(1) == self.max_length: loss_mask = torch.logical_and( loss_mask, action_mask[:, -1] == False, ) effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() self.total_sample_count += total_samples.item() mean_kl, mean_loss = [], [] if self.grpo_config.get("dynamic_batching", True): need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration. num_excessive_samples = ( int( (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations) / self.num_generations / self.dp_size ) * self.num_generations ) if num_excessive_samples > 0: data = { k: ( v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)] if k in [ "input_ids", "attention_mask", "action_log_probs", "action_mask", "response_idx", "gt_answer", ] else v ) for k, v in data.items() } action_mask = action_mask[ : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0) ] loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)] advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)] else: num_excessive_samples = 0 else: # If dynamic batching is disabled, we need to use all samples for training. 
            need_update = (step_idx + 1) % self.num_microbatches == 0
            num_excessive_samples = 0

        pbar.set_postfix(
            {
                "Step": self.global_step + 1,
                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
            }
        )

        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
        ctx = (
            nullcontext()
            if need_update or self.booster.plugin.zero_stage == 2
            else self.booster.no_sync(self.policy_model, self.optimizer)
        )

        with ctx:
            for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                input_ids_forward_micro_batch = data["input_ids"][
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                attention_mask_forward_micro_batch = data["attention_mask"][
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                action_mask_forward_micro_batch = action_mask[
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]
                loss_mask_forward_micro_batch = (
                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
                    if loss_mask is not None
                    else None
                )
                advantages_forward_micro_batch = advantages[
                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                ]

                if self.plugin.pp_size > 1:
                    # Support training with PP.
                    if self.policy_loss_fn.beta > 0:
                        with torch.no_grad():
                            reference_model_outputs = self.booster.execute_pipeline(
                                iter(
                                    [
                                        {
                                            "input_ids": input_ids_forward_micro_batch,
                                            "attention_mask": attention_mask_forward_micro_batch,
                                        }
                                    ]
                                ),
                                self.reference_model,
                                criterion=lambda outputs, inputs: torch.tensor(
                                    [0.0], device=action_mask.device
                                ),  # dummy criterion
                                optimizer=None,
                                return_loss=False,
                                return_outputs=True,
                            )

                        if self.booster.plugin.stage_manager.is_last_stage():
                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
                            reference_action_log_probs = calc_action_log_probs(
                                reference_model_logits / self.generate_config["temperature"],
                                input_ids_forward_micro_batch,
                                num_action,
                                self.plugin.shard_config,
                            )
                        else:
                            # Dummy reference logprobs for data iterator.
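                            # Only the last pipeline stage receives the logits from execute_pipeline; other
                            # stages have nothing to compute for the reference log probs.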
                            reference_action_log_probs = None
                    else:
                        reference_action_log_probs = None

                    data_policy_forward = {
                        "input_ids": input_ids_forward_micro_batch,
                        "attention_mask": attention_mask_forward_micro_batch,
                        "action_mask": action_mask_forward_micro_batch,
                        "advantages": advantages_forward_micro_batch,
                        "loss_mask": loss_mask_forward_micro_batch,
                        "source": self.rank,
                    }
                    if reference_action_log_probs is not None:
                        data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                    kl = []

                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
                        action_log_probs = calc_action_log_probs(
                            action_logits / self.generate_config["temperature"],
                            inputs["input_ids"],
                            num_action,
                            self.plugin.shard_config,
                        )
                        if "reference_action_log_probs" in inputs:
                            per_token_kl = (
                                torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
                                - (inputs["reference_action_log_probs"] - action_log_probs)
                                - 1
                            )
                            approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
                                inputs["action_mask"], dim=-1
                            )
                            kl.append(approx_kl.mean())
                        else:
                            per_token_kl = 0.0
                            kl.append(0.0)

                        loss, _ = self.policy_loss_fn(
                            action_log_probs,
                            action_log_probs,
                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                            per_token_kl,
                            inputs["action_mask"],
                            loss_mask=inputs["loss_mask"],
                            total_effective_tokens_in_batch=total_effective_tokens_count,
                        )
                        return loss

                    policy_model_outputs = self.booster.execute_pipeline(
                        iter([data_policy_forward]),
                        self.policy_model,
                        criterion=_criterion,
                        optimizer=self.optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    loss = policy_model_outputs["loss"]

                    if self.booster.plugin.stage_manager.is_last_stage():
                        if len(kl) > 0:
                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                            mean_kl.append(kl)
                        mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                else:
                    policy_model_logits = self.policy_model(
                        input_ids=input_ids_forward_micro_batch,
                        attention_mask=attention_mask_forward_micro_batch,
                    ).logits
                    action_log_probs = calc_action_log_probs(
                        policy_model_logits / self.generate_config["temperature"],
                        input_ids_forward_micro_batch,
                        num_action,
                        self.plugin.shard_config,
                    )

                    if self.policy_loss_fn.beta > 0:
                        with torch.no_grad():
                            reference_model_logits = self.reference_model(
                                input_ids=input_ids_forward_micro_batch,
                                attention_mask=attention_mask_forward_micro_batch,
                            ).logits
                        reference_action_log_probs = calc_action_log_probs(
                            reference_model_logits / self.generate_config["temperature"],
                            input_ids_forward_micro_batch,
                            num_action,
                            self.plugin.shard_config,
                        )
                        per_token_kl = (
                            torch.exp(reference_action_log_probs - action_log_probs)
                            - (reference_action_log_probs - action_log_probs)
                            - 1
                        )
                        kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
                            action_mask_forward_micro_batch, dim=-1
                        )
                    else:
                        per_token_kl = 0.0
                        kl = None

                    loss, _ = self.policy_loss_fn(
                        action_log_probs,
                        old_action_log_probs[
                            forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                        ],  # slice the old log probs to match the current micro-batch
                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                        per_token_kl,
                        action_mask_forward_micro_batch,
                        loss_mask=loss_mask_forward_micro_batch,
                        total_effective_tokens_in_batch=total_effective_tokens_count,
                    )

                    self.booster.backward(loss, self.optimizer)
                    loss = all_reduce_mean(loss, self.plugin)
                    # Calculate accumulated values.
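                    # Loss (and KL below) are averaged across data-parallel ranks so the accumulated logs are global means.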
                    if kl is not None:
                        kl = all_reduce_mean(kl.mean(), self.plugin)
                        mean_kl.append(kl.data)
                    mean_loss.append(loss.data)

            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
            ):
                reward = all_reduce_mean(reward.mean(), self.plugin)
                format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
                ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                advantages = all_reduce_mean(advantages.mean(), self.plugin)
                response_length = all_reduce_mean(response_length.mean(), self.plugin)
                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                if self.policy_loss_fn.beta > 0:
                    self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                self.accum_reward.add_(reward.data)
                self.accum_format_acc.add_(format_acc.data)
                self.accum_ans_acc.add_(ans_acc.data)
                self.accum_advantages.add_(advantages.data)
                self.accum_response_length.add_(response_length.data)
                self.accum_count += 1

        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.global_step += 1
            sample_utilization = self.effective_sample_count / self.total_sample_count
            self.effective_sample_count = 0
            self.total_sample_count = 0
            loss_scalar = self.accum_loss.item()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
            ):
                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                    self.plugin.pp_size > 1
                    and self.booster.plugin.stage_manager.is_last_stage()
                    and self.tp_rank == 0
                ):
                    to_log_msg = [
                        f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
                        f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
                        f"Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
                        f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                        f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                        f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
                    ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                    print("\n".join(to_log_msg))
                    metrics = {
                        "metrics/reward": self.accum_reward.item() / self.accum_count,
                        "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
                        "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                        "train/sample_utilization": sample_utilization,
                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                    }
                    if self.policy_loss_fn.beta > 0:
                        metrics["train/kl"] = self.accum_kl.item() / self.accum_count
                    self.wandb_run.log(metrics)
                self.accum_loss.zero_()
                self.accum_reward.zero_()
                self.accum_ans_acc.zero_()
                self.accum_format_acc.zero_()
                self.accum_kl.zero_()
                self.accum_advantages.zero_()
                self.accum_response_length.zero_()
                self.accum_count = 0
            return loss_scalar, num_excessive_samples // self.num_generations
        else:
            return None, num_excessive_samples // self.num_generations

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict