mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 20:23:41 +00:00
polish
This commit is contained in:
parent
f736d747e3
commit
070907dd7f
@ -356,6 +356,14 @@ def apply_chat_template_and_mask(
|
||||
truncation: bool = True,
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n"
|
||||
|
||||
system_element = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
|
||||
# Format for RL.
|
||||
gt_answer = None
|
||||
if "messages" in chat and "gt_answer" in chat:
|
||||
@ -365,7 +373,7 @@ def apply_chat_template_and_mask(
|
||||
tokens = []
|
||||
assistant_mask = []
|
||||
for i, msg in enumerate(chat):
|
||||
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
|
||||
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
|
||||
# remove unexpected bos token
|
||||
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
|
||||
msg_tokens = msg_tokens[1:]
|
||||
@ -378,14 +386,15 @@ def apply_chat_template_and_mask(
|
||||
if max_length is not None:
|
||||
if padding and len(tokens) < max_length:
|
||||
to_pad = max_length - len(tokens)
|
||||
if tokenizer.padding_side == "right":
|
||||
tokens.extend([tokenizer.pad_token_id] * to_pad)
|
||||
assistant_mask.extend([False] * to_pad)
|
||||
attention_mask.extend([0] * to_pad)
|
||||
else:
|
||||
tokens = [tokenizer.pad_token_id] * to_pad + tokens
|
||||
assistant_mask = [False] * to_pad + assistant_mask
|
||||
attention_mask = [0] * to_pad + attention_mask
|
||||
# Left padding for generation.
|
||||
# if tokenizer.padding_side == "right":
|
||||
# tokens.extend([tokenizer.pad_token_id] * to_pad)
|
||||
# assistant_mask.extend([False] * to_pad)
|
||||
# attention_mask.extend([0] * to_pad)
|
||||
# else:
|
||||
tokens = [tokenizer.pad_token_id] * to_pad + tokens
|
||||
assistant_mask = [False] * to_pad + assistant_mask
|
||||
attention_mask = [0] * to_pad + attention_mask
|
||||
if truncation and len(tokens) > max_length:
|
||||
tokens = tokens[:max_length]
|
||||
assistant_mask = assistant_mask[:max_length]
|
||||
|
@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean, is_rank_0
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
@ -77,8 +77,15 @@ class GRPOConsumer(BaseConsumer):
|
||||
)
|
||||
|
||||
self.policy_loss_fn = PolicyLoss()
|
||||
if is_rank_0():
|
||||
self.run = wandb.init(project="Colossal-GRPO-Test4")
|
||||
self.global_step = 0
|
||||
if self.rank == 0:
|
||||
self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True)
|
||||
# import os
|
||||
# import time
|
||||
|
||||
# log_dir = self.wandb_run.dir
|
||||
# # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||
# # self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
@ -115,10 +122,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
)["logits"]
|
||||
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
|
||||
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
|
||||
|
||||
# GRPO advantage calculation
|
||||
@ -126,7 +134,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
action_mask, dim=-1
|
||||
)
|
||||
|
||||
reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"])
|
||||
reward = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
reward = kl + reward
|
||||
# [batch_size, num_generations]
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
@ -163,11 +173,19 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if is_rank_0():
|
||||
if self.rank == 0:
|
||||
print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item())
|
||||
self.run.log(
|
||||
{"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()}
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"train/loss": self.accum_loss.item(),
|
||||
"train/reward": self.accum_reward.item(),
|
||||
"train/kl": self.accum_kl.item(),
|
||||
}
|
||||
)
|
||||
# self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step)
|
||||
# self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step)
|
||||
# self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step)
|
||||
# self.global_step += 1
|
||||
self.accum_loss.zero_()
|
||||
self.accum_reward.zero_()
|
||||
self.accum_kl.zero_()
|
||||
|
@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
)
|
||||
FORCE_GENERATE_CONFIG = dict(
|
||||
logprobs=0,
|
||||
n=4,
|
||||
)
|
||||
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
@ -166,19 +167,24 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
self.generate_config = SamplingParams(**generate_config)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
response_start_idx = input_ids.size(1)
|
||||
outputs = self.llm.generate(
|
||||
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
|
||||
)
|
||||
out_tokens = []
|
||||
out_len = []
|
||||
log_probs = []
|
||||
response_idx = []
|
||||
for out in outputs:
|
||||
for output_i in out.outputs:
|
||||
out_len.append(len(output_i.token_ids))
|
||||
out_tokens.append(list(output_i.token_ids))
|
||||
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
|
||||
assert len(output_i.logprobs) == len(output_i.token_ids)
|
||||
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
|
||||
log_probs.append(p)
|
||||
@ -195,6 +201,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
|
||||
out_tokens = torch.tensor(out_tokens)
|
||||
log_probs = torch.tensor(log_probs)
|
||||
response_idx = torch.tensor(response_idx)
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
num_returns = action_mask.size(0) // attention_mask.size(0)
|
||||
@ -209,9 +217,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if "gt_answer" in kwargs:
|
||||
data["gt_answer"] = kwargs["gt_answer"]
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
|
@ -3,12 +3,14 @@ import torch
|
||||
from .reward_utils import extract_solution, validate_response_structure
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, **kwargs):
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
reward = torch.tensor(0.0).to(input_ids.device)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
|
||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||
|
||||
@ -29,7 +31,7 @@ def gsm8k_reward_fn(input_ids, **kwargs):
|
||||
reward = torch.tensor(0.0).to(input_ids.device)
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||
is_valid = True
|
||||
try:
|
||||
|
@ -16,6 +16,7 @@ class VerifiableReward:
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
gt_answer: List[torch.Tensor] = None,
|
||||
response_idx: List[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get batch size
|
||||
bs = input_ids.size(0)
|
||||
@ -30,6 +31,7 @@ class VerifiableReward:
|
||||
reward_fn(
|
||||
input_ids[i],
|
||||
gt_answer=gt_answer[i],
|
||||
response_idx=response_idx[i],
|
||||
**self.kwargs,
|
||||
)
|
||||
for i in range(bs)
|
||||
|
@ -51,13 +51,17 @@ if __name__ == "__main__":
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.6,
|
||||
gpu_memory_utilization=0.7,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=256,
|
||||
max_tokens=2048,
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"],
|
||||
temperature=0.2,
|
||||
top_p=0.95,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user