diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 04093e705..bfee5dce0 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -356,6 +356,14 @@ def apply_chat_template_and_mask(
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+
+    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be an integer without unit, currency mark, thousands separator or other text, i.e., <answer> 123 </answer>.\n"
+
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
+
     # Format for RL.
     gt_answer = None
     if "messages" in chat and "gt_answer" in chat:
@@ -365,7 +373,7 @@ def apply_chat_template_and_mask(
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
@@ -378,14 +386,15 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            # if tokenizer.padding_side == "right":
+            #     tokens.extend([tokenizer.pad_token_id] * to_pad)
+            #     assistant_mask.extend([False] * to_pad)
+            #     attention_mask.extend([0] * to_pad)
+            # else:
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
         if truncation and len(tokens) > max_length:
             tokens = tokens[:max_length]
             assistant_mask = assistant_mask[:max_length]
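The loader now always left-pads prompts. A minimal sketch of the effect, using a made-up pad_token_id of 0 and toy token ids (not from the patch): generation appends new tokens on the right, so padding must sit on the left to keep the prompt contiguous with the generated continuation.

```python
# Illustration only: left padding mirrors the unconditional branch in the hunk above.
import torch

def left_pad(tokens: list, max_length: int, pad_token_id: int = 0):
    to_pad = max_length - len(tokens)
    padded = [pad_token_id] * to_pad + tokens          # pad ids on the left
    attention_mask = [0] * to_pad + [1] * len(tokens)  # mask out the padding
    return torch.tensor(padded), torch.tensor(attention_mask)

tokens, mask = left_pad([101, 2054, 2003], max_length=6)
print(tokens)  # tensor([   0,    0,    0,  101, 2054, 2003])
print(mask)    # tensor([0, 0, 0, 1, 1, 1])
```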
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index d88df2360..2f230f5ed 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, is_rank_0
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.optimizer import HybridAdam
@@ -77,8 +77,15 @@ class GRPOConsumer(BaseConsumer):
         )
         self.policy_loss_fn = PolicyLoss()
 
-        if is_rank_0():
-            self.run = wandb.init(project="Colossal-GRPO-Test4")
+        self.global_step = 0
+        if self.rank == 0:
+            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
+            # import os
+            # import time
+
+            # log_dir = self.wandb_run.dir
+            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            # # self.writer = SummaryWriter(log_dir=log_dir)
 
     def setup(self):
         super().setup()
@@ -115,10 +122,11 @@ class GRPOConsumer(BaseConsumer):
             )["logits"]
             action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
 
-            reference_model_logits = self.reference_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=data["input_ids"],
+                    attention_mask=data["attention_mask"],
+                )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
 
             # GRPO advantage calculation
@@ -126,7 +134,9 @@ class GRPOConsumer(BaseConsumer):
                 action_mask, dim=-1
             )
 
-            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
             reward = kl + reward
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
@@ -163,11 +173,19 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
-            if is_rank_0():
+            if self.rank == 0:
                 print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
-                self.run.log(
-                    {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                self.wandb_run.log(
+                    {
+                        "train/loss": self.accum_loss.item(),
+                        "train/reward": self.accum_reward.item(),
+                        "train/kl": self.accum_kl.item(),
+                    }
                 )
+                # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
+                # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
+                # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
+            self.global_step += 1
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
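The consumer reshapes rewards into one row per prompt via `reward.view(-1, self.num_generations)`. The sketch below is illustrative only, with toy reward values; the exact normalization inside `GRPOConsumer` may differ, but this is the group-relative advantage that the reshaped layout enables.

```python
# Toy example: group-relative advantages over a (num_prompts, num_generations) reward grid.
import torch

num_generations = 4
reward = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])  # 2 prompts x 4 samples

group_reward = reward.view(-1, num_generations)       # one row per prompt
baseline = group_reward.mean(dim=1, keepdim=True)     # per-prompt mean reward
std = group_reward.std(dim=1, keepdim=True)
advantage = (group_reward - baseline) / (std + 1e-6)  # group-relative advantage
print(advantage)
```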
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 210ed5036..bc0ae5c36 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
+        n=4,
     )
 
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -166,19 +167,24 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
+        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
+        response_start_idx = input_ids.size(1)
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
         )
        out_tokens = []
         out_len = []
         log_probs = []
+        response_idx = []
         for out in outputs:
             for output_i in out.outputs:
                 out_len.append(len(output_i.token_ids))
                 out_tokens.append(list(output_i.token_ids))
+                response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
                 assert len(output_i.logprobs) == len(output_i.token_ids)
                 p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
                 log_probs.append(p)
@@ -195,6 +201,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
 
         out_tokens = torch.tensor(out_tokens)
         log_probs = torch.tensor(log_probs)
+        response_idx = torch.tensor(response_idx)
+
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
             num_returns = action_mask.size(0) // attention_mask.size(0)
@@ -209,9 +217,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
         if "gt_answer" in kwargs:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
-            data["gt_answer"] = kwargs["gt_answer"]
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
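With `n=4` samples per prompt, every tensor returned by `generate` is reshaped to `(micro_batch_size, num_generations, seq_len)`, and `gt_answer` is repeated along `dim=1` so each sample lines up with its ground truth. A shape-only sketch with toy sizes, assuming `gt_answer` arrives as `(batch, 1, answer_len)`:

```python
# Shape-only illustration of the reshape and repeat_interleave in the hunk above.
import torch

micro_batch_size, num_generations, seq_len = 2, 4, 8
input_ids = torch.zeros(micro_batch_size * num_generations, seq_len, dtype=torch.long)

data = {"input_ids": input_ids}
data = {k: v.view(micro_batch_size, num_generations, v.size(-1)) for k, v in data.items()}
print(data["input_ids"].shape)  # torch.Size([2, 4, 8])

gt_answer = torch.zeros(micro_batch_size, 1, 5, dtype=torch.long)  # one answer per prompt
print(gt_answer.repeat_interleave(num_generations, dim=1).shape)   # torch.Size([2, 4, 5])
```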
+ data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1) data = {k: v.to(get_current_device()) for k, v in data.items()} return data diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index c7b452c54..c92f822f7 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,12 +3,14 @@ import torch from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(input_ids, gt_answer, **kwargs): +def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] reward = torch.tensor(0.0).to(input_ids.device) + s, e = response_idx[0], response_idx[1] if gt_answer is None: return reward - decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True) + + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0)) final_answer, processed_str = extract_solution(decoded_final_answer) @@ -29,7 +31,7 @@ def gsm8k_reward_fn(input_ids, **kwargs): reward = torch.tensor(0.0).to(input_ids.device) if gt_answer is None: return reward - decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True) + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) final_answer, processed_str = extract_solution(decoded_final_answer) is_valid = True try: diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index fe889a7f4..b43ba65c0 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -16,6 +16,7 @@ class VerifiableReward: self, input_ids: torch.LongTensor, gt_answer: List[torch.Tensor] = None, + response_idx: List[torch.Tensor] = None, ) -> torch.Tensor: # Get batch size bs = input_ids.size(0) @@ -30,6 +31,7 @@ class VerifiableReward: reward_fn( input_ids[i], gt_answer=gt_answer[i], + response_idx=response_idx[i], **self.kwargs, ) for i in range(bs) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a6f82b3be..1b5c18486 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -51,13 +51,17 @@ if __name__ == "__main__": elif args.backend == "vllm": inference_model_config.update( dict( - gpu_memory_utilization=0.6, + gpu_memory_utilization=0.7, ) ) generate_config.update( dict( - max_tokens=256, + max_tokens=2048, ignore_eos=True, + include_stop_str_in_output=True, + stop=[""], + temperature=0.2, + top_p=0.95, ) ) else: