Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 05:04:47 +00:00)

Merge pull request #6314 from hpcaitech/grpo-reward-dev

[feat] upgrade reward functions

Commit bcf2459db5
.gitignore (vendored)

@@ -166,3 +166,4 @@ applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
 applications/ColossalChat/eval
+applications/ColossalChat/rollouts
applications/ColossalChat/coati/distributed/grpo_consumer.py

@@ -120,14 +120,20 @@ class GRPOConsumer(BaseConsumer):
             "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
         )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
+        response_format_tags = (
+            {
+                "think_start": {"text": "<think>", "num_occur": 1},
+                "think_end": {"text": "</think>", "num_occur": 1},
+                "answer_start": {"text": "<answer>", "num_occur": 1},
+                "answer_end": {"text": "</answer>", "num_occur": 1},
+            }
+            if grpo_config.get("reward_fn_type") == "think_answer_tags"
+            else None
+        )
         reward_model_kwargs = {
-            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
         }
         self.reward_model = VerifiableReward(
             reward_fns=[
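Note on the consumer change above: response_format_tags is now built only when the configured reward type uses think/answer tags, and the reward kwargs are whitelisted from grpo_config (with max_new_tokens replacing max_length). A minimal sketch (not ColossalAI code) with a made-up grpo_config showing both behaviors:

# Stand-in grpo_config; "think_answer_tags" is the value that enables the tags.
grpo_config = {
    "reward_fn_type": "boxed",
    "soft_over_length_punishment": True,
    "max_new_tokens": 4096,
    "cache_length": 1024,
    "beta": 0.01,  # unrelated keys are filtered out below
}

response_format_tags = (
    {
        "think_start": {"text": "<think>", "num_occur": 1},
        "think_end": {"text": "</think>", "num_occur": 1},
        "answer_start": {"text": "<answer>", "num_occur": 1},
        "answer_end": {"text": "</answer>", "num_occur": 1},
    }
    if grpo_config.get("reward_fn_type") == "think_answer_tags"
    else None
)
print(response_format_tags)  # None for any other reward_fn_type

reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
print(reward_model_kwargs)
# {'soft_over_length_punishment': True, 'max_new_tokens': 4096, 'cache_length': 1024}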
applications/ColossalChat/coati/distributed/launch.py

@@ -55,6 +55,8 @@ def launch_distributed(
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
     eval_generation_config: Optional[Dict[str, Any]] = None,
+    log_rollout_interval: int = 20,
+    rollout_save_dir: str = "./rollout",
 ):
     if core_algo not in ALGO_MAP:
         raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -72,6 +74,10 @@ def launch_distributed(

     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
+    rollout_log_file = os.path.join(
+        rollout_save_dir,
+        f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
+    )

     procs = []
     for i in range(num_producers):
@@ -98,6 +104,8 @@ def launch_distributed(
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
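The rollout_log_file path combines the project name with the wandb group UUID, so each launch writes a fresh JSONL file. A short sketch with stand-in project_name and rollout_save_dir values:

import os
import uuid

project_name = "GRPO demo"  # hypothetical value
rollout_save_dir = "./rollouts"

wandb_group_name = str(uuid.uuid4())
rollout_log_file = os.path.join(
    rollout_save_dir,
    f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)
print(rollout_log_file)
# e.g. ./rollouts/GRPO_demo_run_2f9c...-....jsonl (unique per launch)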
applications/ColossalChat/coati/distributed/producer.py

@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 from typing import Any, Dict, Optional

@@ -49,7 +50,8 @@ class BaseProducer:
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
-        wandb_log_rollout_interval: int = 20,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -70,9 +72,16 @@ class BaseProducer:
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
         self.eval_mode = False
-        self.wandb_rollout_data = []
-        self.wandb_log_rollout_interval = wandb_log_rollout_interval
+        self.log_rollout_interval = log_rollout_interval
         self.latest_rollout_log_step = -1
+        if producer_idx == 0:
+            if os.path.exists(rollout_log_file):
+                raise ValueError(
+                    f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
+                )
+            else:
+                os.makedirs(os.path.dirname(rollout_log_file), exist_ok=True)
+                self.rollout_log_file = open(rollout_log_file, "w", encoding="utf8")
         if self.producer_idx == 0:
             self.wandb_run = wandb.init(
                 project=project_name,
@@ -320,6 +329,8 @@ class SimpleProducer(BaseProducer):
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        log_rollout_interval: int = 20,
+        rollout_log_file: str = "./rollout_log.jsonl",
     ):
         super().__init__(
             producer_idx,
@@ -342,6 +353,8 @@ class SimpleProducer(BaseProducer):
             project_name=project_name,
             run_name=run_name,
             wandb_group_name=wandb_group_name,
+            log_rollout_interval=log_rollout_interval,
+            rollout_log_file=rollout_log_file,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
@@ -353,26 +366,31 @@ class SimpleProducer(BaseProducer):
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 0 and not self.eval_mode:
-            wandb_rollout_data = self.wandb_rollout_data + [
-                [
-                    str(self.consumer_global_step),
-                    str(self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)),
-                ]
-            ]
             if (
-                self.consumer_global_step - self.latest_rollout_log_step >= self.wandb_log_rollout_interval
+                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                 or self.latest_rollout_log_step == -1
             ):
-                self.wandb_rollout_data = wandb_rollout_data
-                self.latest_rollout_log_step = self.consumer_global_step
-                self.wandb_run.log(
-                    {
-                        "rollout/rollout_examples": wandb.Table(
-                            columns=["train_step", "rollout_examples"], data=wandb_rollout_data
-                        )
-                    }
-                )
+                new_record = (
+                    json.dumps(
+                        {
+                            "train_step": self.consumer_global_step,
+                            "rollout": self.tokenizer.batch_decode(
+                                rollouts["input_ids"][:, 0], skip_special_tokens=True
+                            ),
+                        }
+                    )
+                    + "\n"
+                )
+                self.rollout_log_file.write(new_record)
+                self.rollout_log_file.flush()
+                self.latest_rollout_log_step = self.consumer_global_step
         return rollouts

+    def __del__(self):
+        if self.producer_idx == 0:
+            self.wandb_run.finish()
+        if hasattr(self, "rollout_log_file"):
+            self.rollout_log_file.close()
+
     def load_state_dict(self, state_dict):
         self.model.load_state_dict(state_dict)
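With this change, producer 0 appends one json.dumps(...) + "\n" record per logged step, gated by log_rollout_interval, instead of accumulating a growing wandb.Table. A self-contained sketch (the file name is a stand-in, not from the PR) of writing records in the same format and reading them back:

import json

# Write two records the way BaseProducer does: one JSON object per line, flushed.
with open("rollout_log_example.jsonl", "w", encoding="utf8") as f:
    for step, texts in [(0, ["first rollout"]), (20, ["later rollout"])]:
        f.write(json.dumps({"train_step": step, "rollout": texts}) + "\n")

# Read them back line by line for inspection.
with open("rollout_log_example.jsonl", encoding="utf8") as f:
    for line in f:
        record = json.loads(line)
        print(record["train_step"], record["rollout"][0])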
applications/ColossalChat/coati/distributed/reward/reward_fn.py

@@ -1,7 +1,77 @@
 import torch
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    exact_match_result = (
+        SUCCESS
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
+        else MATCHING_FAIL
+    )
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif exact_match_result == SUCCESS:
+        # sometimes for answers that's not a (valid) math expression, math_verify will fail
+        ans_acc += 1
+        if math_verify_result == CANNOT_PARSE_PREDICTION:
+            reward += (
+                acc_score / 2
+            )  # not a valid latex math representation, but the answer is correct, receive half of the score
+        else:
+            reward += acc_score
+    return reward, ans_acc
+
+
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
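To make the scoring rules concrete: verify_model_answer gives full credit when math_verify agrees, falls back to a normalized exact string match otherwise, and gives half credit when the prediction is string-equal but cannot even be parsed as math. A self-contained restatement of those branches (not part of this commit) with stubbed status codes and an arbitrary acc_score of 10.0:

CANNOT_PARSE_GT_ANSWER = -1
CANNOT_PARSE_PREDICTION = -2
SUCCESS = 1
MATCHING_FAIL = 0

def score(math_verify_result, exact_match_result, acc_score=10.0):
    reward, ans_acc = 0.0, 0
    if math_verify_result == SUCCESS:
        ans_acc += 1
        reward += acc_score
    elif exact_match_result == SUCCESS:
        ans_acc += 1
        # string-equal but the prediction itself is unparseable as math: half credit
        reward += acc_score / 2 if math_verify_result == CANNOT_PARSE_PREDICTION else acc_score
    return reward, ans_acc

print(score(SUCCESS, MATCHING_FAIL))            # (10.0, 1): math-verified
print(score(CANNOT_PARSE_PREDICTION, SUCCESS))  # (5.0, 1): exact match, unparseable prediction
print(score(CANNOT_PARSE_GT_ANSWER, SUCCESS))   # (10.0, 1): exact match, gt unparseable
print(score(MATCHING_FAIL, MATCHING_FAIL))      # (0.0, 0): no credit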
@@ -14,15 +84,18 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return reward
+        raise ValueError("no gt_answer is provided, please check your training dataset.")

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
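The penalty window now keys off max_new_tokens (generation budget) rather than max_length (prompt plus generation). A worked example with illustrative numbers (max_new_tokens=4096, cache_length=1024, acc_score=10.0): responses up to 3072 tokens are unpenalized, then the adjustment falls linearly toward -acc_score just before the cap:

max_new_tokens, cache_length, acc_score = 4096, 1024, 10.0

def length_penalty(res_length):
    # Same window as the diff above: (max_new_tokens - cache_length, max_new_tokens).
    if max_new_tokens - cache_length < res_length < max_new_tokens:
        return ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score
    return 0.0

for res_length in (3072, 3584, 4095, 4096):
    print(res_length, round(length_penalty(res_length), 2))
# 3072 -> 0.0   (window starts just above this)
# 3584 -> -5.0  (halfway through the window)
# 4095 -> -9.99 (nearly the full -acc_score)
# 4096 -> 0.0 here, but the separate over-length check zeroes the whole reward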
@@ -35,16 +108,16 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if (
-        format_valid
-        and final_answer is not None
-        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
-    ):
-        ans_acc += 1
-        reward += acc_score
-
-    reward = reward + length_reward
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward
+
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
@@ -71,31 +146,45 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     s, e = response_idx[0], response_idx[1]

     length_reward = 0.0
-    if soft_over_length_punishment:
-        max_length = kwargs.get("max_length", 1024 * 4)
-        cache_length = kwargs.get("cache_length", 512)
-        res_length = e.item() - s.item() + 1
-        if max_length - cache_length < res_length < max_length:
-            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
+    res_length = e.item() - s.item() + 1
+    if not eval_mode:
+        max_new_tokens = kwargs["max_new_tokens"]
+    else:
+        max_new_tokens = -1  # for eval mode, we don't need to check the length
+    if not eval_mode and soft_over_length_punishment:
+        cache_length = kwargs["cache_length"]
+        if max_new_tokens - cache_length < res_length < max_new_tokens:
+            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        raise ValueError("no gt_answer is provided, please check your training dataset.")

     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs and kwargs["tags"]:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )
     # Check format accuracy
     if format_valid:
         format_acc += 1
         reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
-        ans_acc += 1
-        reward += acc_score
+    if final_answer is not None:
+        if eval_mode or format_valid:
+            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
+        if not eval_mode:
+            reward = reward + length_reward

-    reward = reward + length_reward
+    # Check if the sequence is over length
+    if not eval_mode and res_length >= max_new_tokens:
+        reward *= 0.0

     if not eval_mode:
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
     else:
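The boxed reward now optionally enforces tag counts on top of requiring a \boxed{} answer: each tag's text must appear exactly num_occur times in the decoded response. A minimal sketch of that occurrence check with the default think/answer tags and a toy response:

tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

decoded_final_answer = "<think>2+2=4</think><answer>\\boxed{4}</answer>"  # toy response
format_valid = all(
    decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags
)
print(format_valid)  # True

# A second think block makes <think> occur twice, so the check fails.
doubled = "<think>x</think>" + decoded_final_answer
print(all(doubled.count(tags[t]["text"]) == tags[t]["num_occur"] for t in tags))  # False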
@@ -107,4 +196,6 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             "parsed": final_answer,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),
+            "response_length": res_length,
+            "reward": reward.item(),
         }
applications/ColossalChat/rl_example.py

@@ -118,6 +118,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
     )
+    parser.add_argument(
+        "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -198,6 +201,8 @@ if __name__ == "__main__":
            "beta": args.kl_coeff,  # KL penalty coefficient
            "loss_variation": "sample_level",
            "reward_fn_type": args.reward_type,
+           "max_length": args.max_new_tokens + args.max_prompt_tokens,
+           "max_new_tokens": args.max_new_tokens,
        }
    elif args.algo == "DAPO":
        # DAPO variant settings
@@ -213,6 +218,7 @@ if __name__ == "__main__":
            "loss_variation": "token_level",
            "soft_over_length_punishment": True,
            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+           "max_new_tokens": args.max_new_tokens,
            "cache_length": min(1024, int(args.max_new_tokens / 4)),
            "filter_truncated_response": True,
            "reward_fn_type": args.reward_type,
@@ -266,4 +272,6 @@ if __name__ == "__main__":
        eval_interval=args.eval_interval,
        eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
        eval_generation_config=eval_generation_config,
+       log_rollout_interval=20,
+       rollout_save_dir=args.rollout_save_dir,
    )
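For reference, a sketch of the two length-related keys these hunks thread into the GRPO and DAPO configs, computed from stand-in argument values (max_new_tokens=4096, max_prompt_tokens=512 are examples, not defaults from the script):

max_new_tokens, max_prompt_tokens = 4096, 512
grpo_config_update = {
    "max_length": max_new_tokens + max_prompt_tokens,  # prompt + generation budget
    "max_new_tokens": max_new_tokens,  # generation-only budget used by the reward fns
}
print(grpo_config_update)  # {'max_length': 4608, 'max_new_tokens': 4096}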