Tong Li 2025-02-28 10:16:42 +08:00
parent f736d747e3
commit 070907dd7f
6 changed files with 74 additions and 26 deletions

View File

@@ -356,6 +356,14 @@ def apply_chat_template_and_mask(
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+    system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n"
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
     # Format for RL.
     gt_answer = None
     if "messages" in chat and "gt_answer" in chat:
@@ -365,7 +373,7 @@ def apply_chat_template_and_mask(
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
@@ -378,14 +386,15 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            # if tokenizer.padding_side == "right":
+            #     tokens.extend([tokenizer.pad_token_id] * to_pad)
+            #     assistant_mask.extend([False] * to_pad)
+            #     attention_mask.extend([0] * to_pad)
+            # else:
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
         if truncation and len(tokens) > max_length:
             tokens = tokens[:max_length]
             assistant_mask = assistant_mask[:max_length]
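
Note: the hunk above switches the padding branch to always left-pad, so every prompt ends at the same position and batched generation can start right after it. A minimal standalone sketch of that behaviour (the name left_pad is illustrative, not the repository's API):

from typing import List, Tuple

def left_pad(tokens: List[int], mask: List[int], max_length: int, pad_id: int) -> Tuple[List[int], List[int]]:
    # Pad on the left so the last real token sits at the end of the row,
    # keeping the "next token" slot aligned across the batch for generation.
    to_pad = max_length - len(tokens)
    return [pad_id] * to_pad + tokens, [0] * to_pad + mask

tokens, attn = left_pad([101, 2023, 2003], [1, 1, 1], max_length=6, pad_id=0)
assert tokens == [0, 0, 0, 101, 2023, 2003] and attn == [0, 0, 0, 1, 1, 1]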

View File

@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, is_rank_0
+from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from colossalai.nn.optimizer import HybridAdam
@@ -77,8 +77,15 @@ class GRPOConsumer(BaseConsumer):
         )
         self.policy_loss_fn = PolicyLoss()
-        if is_rank_0():
-            self.run = wandb.init(project="Colossal-GRPO-Test4")
+        self.global_step = 0
+        if self.rank == 0:
+            self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
+            # import os
+            # import time
+            # log_dir = self.wandb_run.dir
+            # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            # # self.writer = SummaryWriter(log_dir=log_dir)
     def setup(self):
         super().setup()
@@ -115,10 +122,11 @@ class GRPOConsumer(BaseConsumer):
             )["logits"]
             action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
-            reference_model_logits = self.reference_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=data["input_ids"],
+                    attention_mask=data["attention_mask"],
+                )["logits"]
             reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
             # GRPO advantage calculation
@@ -126,7 +134,9 @@ class GRPOConsumer(BaseConsumer):
                 action_mask, dim=-1
             )
-            reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
+            reward = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
             reward = kl + reward
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
@@ -163,11 +173,19 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             loss_scalar = self.accum_loss.item()
-            if is_rank_0():
+            if self.rank == 0:
                 print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
-                self.run.log(
-                    {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
+                self.wandb_run.log(
+                    {
+                        "train/loss": self.accum_loss.item(),
+                        "train/reward": self.accum_reward.item(),
+                        "train/kl": self.accum_kl.item(),
+                    }
                 )
+                # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
+                # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
+                # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
+                # self.global_step += 1
             self.accum_loss.zero_()
            self.accum_reward.zero_()
            self.accum_kl.zero_()
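
Note: the reference-model forward above is wrapped in torch.no_grad() because the reference policy is frozen and its logits only feed the KL penalty; keeping no autograd graph for it saves memory. A minimal sketch of the pattern (ref_model here is a hypothetical stand-in, not the trainer's module):

import torch
import torch.nn as nn

ref_model = nn.Linear(8, 8)          # stand-in for the frozen reference policy
x = torch.randn(2, 8, requires_grad=True)

with torch.no_grad():
    ref_out = ref_model(x)           # no graph is recorded for this forward

assert not ref_out.requires_grad     # safe to treat as a constant in the KL term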

View File

@@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
+        n=4,
     )
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -166,19 +167,24 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
+        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
+        response_start_idx = input_ids.size(1)
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
         log_probs = []
+        response_idx = []
         for out in outputs:
             for output_i in out.outputs:
                 out_len.append(len(output_i.token_ids))
                 out_tokens.append(list(output_i.token_ids))
+                response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
                 assert len(output_i.logprobs) == len(output_i.token_ids)
                 p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
                 log_probs.append(p)
@@ -195,6 +201,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         out_tokens = torch.tensor(out_tokens)
         log_probs = torch.tensor(log_probs)
+        response_idx = torch.tensor(response_idx)
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
             num_returns = action_mask.size(0) // attention_mask.size(0)
@@ -209,9 +217,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
         if "gt_answer" in kwargs:
-            data["gt_answer"] = kwargs["gt_answer"]
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
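
Note: with n=4 completions per prompt, the backend now records an inclusive (start, end) span per completion and regroups every flat tensor into [micro_batch_size, num_generations, last_dim], repeating gt_answer so each generation has a matching label. A rough sketch of that bookkeeping with toy shapes (all names and shapes here are illustrative, not the backend's actual data):

import torch

micro_batch_size, num_generations, seq_len = 2, 4, 10
prompt_len = 6                                    # responses start right after the prompt

# One flat row per completion, as returned by the engine.
input_ids = torch.zeros(micro_batch_size * num_generations, seq_len, dtype=torch.long)
response_idx = torch.tensor([(prompt_len, prompt_len + 3)] * (micro_batch_size * num_generations))

data = {"input_ids": input_ids, "response_idx": response_idx}
# Group the flat rows back by prompt: [micro_batch_size, num_generations, last_dim].
data = {k: v.view(micro_batch_size, num_generations, v.size(-1)) for k, v in data.items()}

# gt_answer arrives once per prompt and is repeated for each generation.
gt_answer = torch.arange(micro_batch_size).view(micro_batch_size, 1)
data["gt_answer"] = gt_answer.repeat_interleave(num_generations, dim=1)
assert data["gt_answer"].shape == (micro_batch_size, num_generations)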

View File

@@ -3,12 +3,14 @@ import torch
 from .reward_utils import extract_solution, validate_response_structure
-def math_reward_fn(input_ids, gt_answer, **kwargs):
+def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
     reward = torch.tensor(0.0).to(input_ids.device)
+    s, e = response_idx[0], response_idx[1]
     if gt_answer is None:
         return reward
-    decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0))
     final_answer, processed_str = extract_solution(decoded_final_answer)
@@ -29,7 +31,7 @@ def gsm8k_reward_fn(input_ids, **kwargs):
     reward = torch.tensor(0.0).to(input_ids.device)
     if gt_answer is None:
         return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     final_answer, processed_str = extract_solution(decoded_final_answer)
     is_valid = True
     try:
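
Note: response_idx holds an inclusive (start, end) pair, so the reward functions now decode input_ids[s : e + 1] instead of the whole sequence, which would include the prompt and padding. A small tokenizer-free sketch of the slicing (toy token ids only):

import torch

input_ids = torch.tensor([11, 12, 13, 21, 22, 23, 24, 0, 0])  # prompt, response, padding
response_idx = torch.tensor([3, 6])                            # inclusive span of the response

s, e = response_idx[0], response_idx[1]
response_tokens = input_ids[s : e + 1]                         # e is inclusive, hence e + 1
assert response_tokens.tolist() == [21, 22, 23, 24]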

View File

@@ -16,6 +16,7 @@ class VerifiableReward:
         self,
         input_ids: torch.LongTensor,
         gt_answer: List[torch.Tensor] = None,
+        response_idx: List[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
@@ -30,6 +31,7 @@ class VerifiableReward:
                 reward_fn(
                     input_ids[i],
                     gt_answer=gt_answer[i],
+                    response_idx=response_idx[i],
                     **self.kwargs,
                 )
                 for i in range(bs)
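
Note: the verifiable reward simply forwards the i-th response span to each reward function alongside the i-th sequence and its label. A compressed sketch of that per-sample loop (dummy_reward_fn is a hypothetical toy rule, not the class itself):

import torch

def dummy_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
    # Toy rule: reward 1.0 if the last response token equals the ground truth.
    s, e = response_idx[0], response_idx[1]
    return torch.tensor(1.0 if input_ids[e].item() == gt_answer.item() else 0.0)

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
gt_answer = torch.tensor([3, 9])
response_idx = torch.tensor([[1, 2], [1, 2]])

bs = input_ids.size(0)
rewards = torch.stack(
    [dummy_reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i]) for i in range(bs)]
)
assert rewards.tolist() == [1.0, 0.0]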

View File

@@ -51,13 +51,17 @@ if __name__ == "__main__":
     elif args.backend == "vllm":
         inference_model_config.update(
             dict(
-                gpu_memory_utilization=0.6,
+                gpu_memory_utilization=0.7,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=256,
+                max_tokens=2048,
                 ignore_eos=True,
+                include_stop_str_in_output=True,
+                stop=["</answer>"],
+                temperature=0.2,
+                top_p=0.95,
             )
         )
     else:
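
Note: the new generation config caps responses at 2048 tokens and stops once "</answer>" is produced, keeping the stop string in the output so the reward parser can still match the closing tag. A sketch of roughly what ends up being passed to vLLM after merging with FORCE_GENERATE_CONFIG (values copied from the diff; treat this as an example, not a recommendation):

from vllm import SamplingParams

sampling_params = SamplingParams(
    n=4,                              # forced by FORCE_GENERATE_CONFIG
    logprobs=0,                       # return the logprob of each sampled token
    max_tokens=2048,
    ignore_eos=True,
    stop=["</answer>"],               # cut generation at the closing answer tag
    include_stop_str_in_output=True,  # keep the tag in the decoded text
    temperature=0.2,
    top_p=0.95,
)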