diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index e23254d1b..5cd97b20a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -12,7 +12,7 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs -from coati.trainer.utils import all_reduce_mean +from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -38,7 +38,7 @@ class GRPOConsumer(BaseConsumer): num_generations=8, use_wandb=True, generate_config=None, - training_config={}, + grpo_config={}, project_name=None, ): super().__init__( @@ -59,7 +59,7 @@ class GRPOConsumer(BaseConsumer): self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) @@ -69,8 +69,9 @@ class GRPOConsumer(BaseConsumer): self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 self.generate_config = generate_config - self.training_config = training_config + self.grpo_config = grpo_config self.project_name = project_name + self.effective_sample_count = 0 # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -79,10 +80,21 @@ class GRPOConsumer(BaseConsumer): self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations - self.filter_range = training_config.get("filter_range", None) + self.filter_range = grpo_config.get("filter_range", None) if self.filter_range is not None: assert len(self.filter_range) == 2, "Filter range should have 2 values." + self.filter_truncated_response = grpo_config.get("filter_truncated_response", False) + if self.filter_truncated_response: + self.max_length = 0 + if "max_tokens" in self.generate_config: + self.max_length = self.generate_config["max_tokens"] + elif "max_new_tokens" in self.generate_config: + self.max_length = self.generate_config["max_new_tokens"] + else: + raise ValueError( + "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." + ) # Initialize verifiable reward. 
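# The constructor change above resolves the generation cap from either a vLLM-style
# or a transformers-style generate_config and stores it as self.max_length for the
# truncated-response filter.  A minimal standalone sketch of that idea, with a
# hypothetical helper name (resolve_max_generation_length is not part of the PR):
def resolve_max_generation_length(generate_config: dict) -> int:
    """Return the generation cap from a vLLM- or transformers-style config."""
    if "max_tokens" in generate_config:  # vLLM
        return generate_config["max_tokens"]
    if "max_new_tokens" in generate_config:  # transformers
        return generate_config["max_new_tokens"]
    raise ValueError(
        "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
    )

# Usage: a response whose action mask is still active at the last position of a
# window that is exactly max_length wide was cut off by the cap; the consumer
# later drops such samples from the loss when filter_truncated_response is set.
# max_length = resolve_max_generation_length({"max_tokens": 2048})   # -> 2048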
response_format_tags = { "think_start": {"text": "", "num_occur": 1}, @@ -90,11 +102,20 @@ class GRPOConsumer(BaseConsumer): "answer_start": {"text": "", "num_occur": 1}, "answer_end": {"text": "", "num_occur": 1}, } + reward_model_kwargs = { + k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"] + } self.reward_model = VerifiableReward( - reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs ) - self.policy_loss_fn = PolicyLoss() + self.policy_loss_fn = PolicyLoss( + clip_eps_low=grpo_config.get("clip_eps_low", 0.2), + clip_eps_high=grpo_config.get("clip_eps_high", 0.2), + skip_threshold=grpo_config.get("skip_threshold", 20.0), + beta=grpo_config.get("beta", 0.01), + loss_variation=grpo_config.get("loss_variation", "sample_level"), + ) self.global_step = 0 self.use_wandb = use_wandb @@ -102,7 +123,7 @@ class GRPOConsumer(BaseConsumer): optimizer=self.optimizer, total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, warmup_steps=0, - eta_min=0.1 * training_config.get("lr", 1e-6), + eta_min=0.1 * grpo_config.get("lr", 1e-6), ) def setup(self): @@ -141,9 +162,65 @@ class GRPOConsumer(BaseConsumer): num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) - forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) + forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) + + reward_group = self.reward_model( + int(step_idx / self.num_microbatches), + data["input_ids"], + gt_answer=data["gt_answer"], + response_idx=data["response_idx"], + ) + + reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) + format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + + # [batch_size, num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) + # [batch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) + # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), + reward_mean_no_length_penalty = ( + (format_reward + acc_reward) + .view(-1, self.num_generations) + .mean(dim=1) + .repeat_interleave(self.num_generations, dim=0) + ) + loss_mask = ( + torch.ones(action_mask.size(0), device=action_mask.device).bool() + if self.filter_range is None + else torch.logical_and( + reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1] + ) + ) + # filter out overlength samples + if self.filter_truncated_response and action_mask.size(1) == self.max_length: + loss_mask = torch.logical_and( + loss_mask, + action_mask[:, -1] == False, + ) + # for i in range(loss_mask.size(0)): + # if loss_mask[i] == False: + # print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i]) + + effective_samples = 
all_reduce_sum(torch.sum(loss_mask), self.plugin) + self.effective_sample_count += effective_samples.item() + + mean_kl, mean_loss = [], [] + + # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out. + # balance between efficiency and accuracy + need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75 + if need_update: + print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****") + self.effective_sample_count = 0 - need_update = (step_idx + 1) % self.num_microbatches == 0 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = ( nullcontext() @@ -151,32 +228,6 @@ class GRPOConsumer(BaseConsumer): else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: - reward_group = self.reward_model( - data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] - ) - - reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) - format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) - acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) - - # [batch_size, num_generations] - - group_reward = reward.view(-1, self.num_generations) - reward_mean = group_reward.mean(dim=1) - # [batch_size x num_generations] - reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) - reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) - # [batch_size x num_generations] - advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) - # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), - loss_mask = ( - None - if self.filter_range is None - else torch.logical_and( - reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] - ).repeat_interleave(self.num_generations, dim=0) - ) - mean_kl, mean_loss = [], [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): input_ids_forward_micro_batch = data["input_ids"][ @@ -199,47 +250,50 @@ class GRPOConsumer(BaseConsumer): if self.plugin.pp_size > 1: # Support training with PP. 
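# The hunk above replaces the fixed per-microbatch optimizer step with a gate on
# the number of valid samples: rewards are computed up front, group-relative
# advantages are derived, all-correct / all-wrong groups and truncated responses
# are masked out, and the gradient update only fires once roughly 75% of the
# target batch consists of valid samples.  A compact single-process sketch of
# that logic (illustrative function names; the PR additionally all-reduces the
# valid count across the data-parallel group and compares the lower filter bound
# against the penalty-free group mean):
import torch

def group_advantages(reward: torch.Tensor, num_generations: int) -> torch.Tensor:
    """reward: [batch * num_generations] -> advantages: [batch * num_generations, 1]."""
    group = reward.view(-1, num_generations)
    mean = group.mean(dim=1).repeat_interleave(num_generations, dim=0)
    std = group.std(dim=1).repeat_interleave(num_generations, dim=0)
    return ((reward - mean) / (std + 1e-4)).unsqueeze(-1)

def valid_sample_mask(
    reward: torch.Tensor,            # [batch * num_generations]
    num_generations: int,
    filter_range=None,               # e.g. [0.05, 9.0]: drop all-wrong / all-perfect groups
    truncated: torch.Tensor = None,  # bool [batch * num_generations], True = hit the length cap
) -> torch.Tensor:
    mask = torch.ones_like(reward, dtype=torch.bool)
    if filter_range is not None:
        group_mean = (
            reward.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
        )
        mask &= (group_mean > filter_range[0]) & (group_mean < filter_range[1])
    if truncated is not None:
        mask &= ~truncated
    return mask

# Dynamic accumulation gate (per-rank counts are summed over the DP group in the PR):
# effective_sample_count += int(valid_sample_mask(...).sum())
# need_update = effective_sample_count >= batch_size * dp_size * num_generations * 0.75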
+ if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) - with torch.no_grad(): - reference_model_outputs = self.booster.execute_pipeline( - iter( - [ - { - "input_ids": input_ids_forward_micro_batch, - "attention_mask": attention_mask_forward_micro_batch, - } - ] - ), - self.reference_model, - criterion=lambda outputs, inputs: torch.tensor( - [0.0], device=action_mask.device - ), # dummy criterion - optimizer=None, - return_loss=False, - return_outputs=True, - ) - - if self.booster.plugin.stage_manager.is_last_stage(): - reference_model_logits = reference_model_outputs["outputs"]["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. + reference_action_log_probs = None else: - # Dummy reference logprobs for data iterator. reference_action_log_probs = None data_policy_forward = { "input_ids": input_ids_forward_micro_batch, "attention_mask": attention_mask_forward_micro_batch, "action_mask": action_mask_forward_micro_batch, - "reference_action_log_probs": reference_action_log_probs, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, "source": self.rank, } + if reference_action_log_probs is not None: + data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] @@ -251,15 +305,20 @@ class GRPOConsumer(BaseConsumer): num_action, self.plugin.shard_config, ) - per_token_kl = ( - torch.exp(inputs["reference_action_log_probs"] - action_log_probs) - - (inputs["reference_action_log_probs"] - action_log_probs) - - 1 - ) - appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( - inputs["action_mask"], dim=-1 - ) - kl.append(appox_kl.mean()) + if "reference_action_log_probs" in inputs: + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + else: + per_token_kl = 0.0 + kl.append(0.0) + loss, skip_update, _ = self.policy_loss_fn( action_log_probs, action_log_probs, @@ -298,25 +357,29 @@ class GRPOConsumer(BaseConsumer): self.plugin.shard_config, ) - with torch.no_grad(): - reference_model_logits = self.reference_model( - input_ids=input_ids_forward_micro_batch, - attention_mask=attention_mask_forward_micro_batch, - ).logits - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) - per_token_kl = ( - 
torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + else: + per_token_kl = 0.0 + kl = None loss, skip_update, _ = self.policy_loss_fn( action_log_probs, @@ -330,9 +393,10 @@ class GRPOConsumer(BaseConsumer): if not skip_update: self.booster.backward(loss, self.optimizer) loss = all_reduce_mean(loss, self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) # Calculate accumulate value. - mean_kl.append(kl.data) + if kl is not None: + kl = all_reduce_mean(kl.mean(), self.plugin) + mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 @@ -343,7 +407,8 @@ class GRPOConsumer(BaseConsumer): advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) - self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + if self.policy_loss_fn.beta > 0: + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) self.accum_format_reward.add_(format_reward.data) self.accum_acc_reward.add_(acc_reward.data) @@ -360,35 +425,32 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): - print( - "Loss:", - self.accum_loss.item() / self.accum_count, - "\nReward:", - self.accum_reward.item() / self.accum_count, - "\nFormat Reward:", - self.accum_format_reward.item() / self.accum_count, - "\nAcc Reward:", - self.accum_acc_reward.item() / self.accum_count, - "\nKL:", - self.accum_kl.item() / self.accum_count, - "\nAdvantages:", - self.accum_advantages.item() / self.accum_count, - "\nResponse Length:", - self.accum_response_length.item() / self.accum_count, - ) - self.wandb_run.log( - { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, - "train/loss": self.accum_loss.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, - "train/advantages": self.accum_advantages.item() / self.accum_count, - "train/learning_rate": self.lr_scheduler.get_last_lr()[0], - "rollout/temperature": data["temperature"].cpu().numpy()[0][0], - } + to_log_msg = ( + f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \ + Reward: {self.accum_reward.item() / self.accum_count:.4f} \ + Format 
Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \ + Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \ + Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \ + Response Length: {self.accum_response_length.item() / self.accum_count:.4f}" + + f" KL: {self.accum_kl.item() / self.accum_count:.4f}" + if self.policy_loss_fn.beta > 0 + else "" ) + print(to_log_msg) + metrics = { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + if self.policy_loss_fn.beta > 0: + metrics["train/kl"] = self.accum_kl.item() / self.accum_count + + self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_reward.zero_() self.accum_acc_reward.zero_() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 699d90a8c..8936752d2 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -40,6 +40,7 @@ def launch_distributed( inference_model_config: Dict[str, Any], generate_config: Dict[str, Any], train_model_config: Dict[str, Any], + grpo_config: Dict[str, Any], plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", @@ -103,11 +104,7 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_minibatch_size, generate_config=generate_config_consumer, - training_config={ - "filter_range": [0.05, 9.0], - "lr": 1e-6, - "train_microbatch_size": train_microbatch_size, - }, + grpo_config=grpo_config, num_generations=num_generations, project_name=project_name, ) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 90ad09736..bdbe64a2a 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -2,7 +2,7 @@ from typing import Optional import torch import torch.nn as nn -from coati.distributed.utils import masked_mean +from coati.distributed.utils import masked_mean, masked_sum class PolicyLoss(nn.Module): @@ -10,11 +10,21 @@ class PolicyLoss(nn.Module): Policy Loss for PPO """ - def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: + def __init__( + self, + clip_eps_low: float = 0.2, + clip_eps_high: float = 0.2, + skip_threshold: float = 20.0, + beta: float = 0.01, + loss_variation: str = "sample_level", + ) -> None: super().__init__() - self.clip_eps = clip_eps + self.clip_eps_low = clip_eps_low + self.clip_eps_high = clip_eps_high self.skip_threshold = skip_threshold self.beta = beta + self.loss_variation = loss_variation + assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}" def forward( self, @@ -32,14 +42,31 @@ class PolicyLoss(nn.Module): ratio = ((log_probs - log_probs.detach()) * action_mask).exp() surr1 = ratio * advantages - surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + surr2 = 
ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages + if self.beta <= 0: + # skip kl term if kl coefficient is zero + per_token_kl = 0.0 loss = -torch.min(surr1, surr2) + self.beta * per_token_kl - if action_mask is not None: - loss = masked_mean(loss, action_mask) - else: - loss = loss.mean(dim=1) - if loss_mask is not None: - loss = loss * loss_mask - loss = loss.mean() + if self.loss_variation == "sample_level": + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.mean() + elif self.loss_variation == "token_level": + total_tokens = 0 + if action_mask is not None: + loss = masked_sum(loss, action_mask) + total_tokens = action_mask.sum(dim=1) + else: + loss = loss.sum(dim=1) + total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1) + if loss_mask is not None: + loss = loss * loss_mask + total_tokens = total_tokens * loss_mask + loss = loss.sum() / (total_tokens.sum() + 1e-8) + return loss, skip, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2c6a24a36..c5681b9b5 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -124,12 +124,12 @@ class BaseProducer: self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() - # linear annealing for 1 episode, temperature from initial to 0.7 + # linear annealing for 1 episode, temperature from initial to 0.9 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ "temperature" - ] + ratio * 0.7 + ] + ratio * 0.9 @ray.remote diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 53bc15e25..a0f92d8c4 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,14 +3,29 @@ import torch from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): +def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs): + tokenizer = kwargs["tokenizer"] + soft_over_length_punishment = kwargs["soft_over_length_punishment"] format_score = 1.0 acc_score = 9.0 - tokenizer = kwargs["tokenizer"] + if step > 30: + format_score = 0.0 + acc_score = 10.0 reward = torch.tensor(0.0) format_reward = torch.tensor(0.0) acc_reward = torch.tensor(0.0) s, e = response_idx[0], response_idx[1] + + length_reward = 0.0 + if soft_over_length_punishment: + max_length = kwargs.get("max_length", 1024 * 4) + cache_length = kwargs.get("cache_length", 512) + res_length = e.item() - s.item() + 1 + if res_length >= max_length: + length_reward = -1.0 * 2 + elif res_length > max_length - cache_length: + length_reward = ((max_length - cache_length) - res_length) / cache_length * 2 + if gt_answer is None: return reward @@ -33,6 +48,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): acc_reward += acc_score reward += acc_score + reward = reward + length_reward + return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py 
b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index ba83f7787..01d8f1663 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -14,6 +14,7 @@ class VerifiableReward: def __call__( self, + step: int, input_ids: torch.LongTensor, gt_answer: List[torch.Tensor] = None, response_idx: List[torch.Tensor] = None, @@ -29,6 +30,7 @@ class VerifiableReward: reward_batch = torch.stack( [ reward_fn( + step, input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 919e4434f..5f7879669 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean + + +def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked sum of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the sum. Default is 1. + + Returns: + torch.Tensor: The masked sum tensor. + + """ + tensor = tensor * mask + return tensor.sum(dim=dim) diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 22a5f492e..45dfab588 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -128,7 +128,21 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor return tensor -def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: +# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor: +# """ +# Performs an all-reduce operation to sum the values of the given tensor across all processes. + +# Args: +# tensor (torch.Tensor): The input tensor to be reduced. + +# Returns: +# torch.Tensor: The reduced tensor with the sum of values across all processes. +# """ +# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) +# return tensor + + +def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Performs an all-reduce operation to sum the values of the given tensor across all processes. @@ -138,5 +152,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The reduced tensor with the sum of values across all processes. 
""" - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + # All reduce sum across DP group + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) return tensor diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 6c43ccd19..d4befa20e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -60,8 +60,8 @@ if __name__ == "__main__": ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) - generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) + train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False) + generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0) if args.backend == "transformers": inference_model_config.update( @@ -102,6 +102,29 @@ if __name__ == "__main__": ) ) + # Default Settings + # grpo_config = { + # "filter_range": [0.05, 9.0], + # "lr": 1e-6, + # "train_microbatch_size": train_microbatch_size, + # } + + # DAPO variant settings + grpo_config = { + "filter_range": [0.05, 9.0], + "lr": 1e-6, + "train_microbatch_size": args.train_microbatch_size, + "clip_eps_low": 0.2, + "clip_eps_high": 0.28, + "skip_threshold": 20.0, + "beta": 0.0, # no KL penalty + "loss_variation": "token_level", + "soft_over_length_punishment": True, + "max_length": 1024 * 2, + "cache_length": 256, + "filter_truncated_response": True, + } + launch_distributed( num_producers=args.num_inferencer, num_proc_per_producer=1, @@ -118,14 +141,17 @@ if __name__ == "__main__": generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - # plugin_config={}, # for zero + grpo_config=grpo_config, plugin_config={ - "pp_size": 2, - "tp_size": 2, - "microbatch_size": args.train_microbatch_size // 2, - "zero_stage": 0, - "max_norm": 1.0, - }, # for pp + "zero_stage": 2, + }, # for zero + # plugin_config={ + # "pp_size": 2, + # "tp_size": 2, + # "microbatch_size": args.train_microbatch_size // 2, + # "zero_stage": 0, + # "max_norm": 1.0, + # }, # for pp inference_backend=args.backend, master_addr="localhost", master_port=29506,