mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[feat] Support DAPO (#6263)
* update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache * add DAPO support * remove format reward * fix filtering, still buggy * small fix * add DAPO support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tested multi-node training; fix bind_batch bug * fix conversation; support sleep mode * support reusing excessive samples * add dynamic batching control flag * add dynamic batching control flag * refactored * fix logging --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,13 +4,22 @@ from .reward_utils import extract_solution, validate_response_structure
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
format_score = 1.0
|
||||
acc_score = 9.0
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||
acc_score = 10.0
|
||||
reward = torch.tensor(0.0)
|
||||
format_reward = torch.tensor(0.0)
|
||||
acc_reward = torch.tensor(0.0)
|
||||
format_acc = torch.tensor(0.0)
|
||||
ans_acc = torch.tensor(0.0)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
|
||||
length_reward = 0.0
|
||||
if soft_over_length_punishment:
|
||||
max_length = kwargs.get("max_length", 1024 * 4)
|
||||
cache_length = kwargs.get("cache_length", 512)
|
||||
res_length = e.item() - s.item() + 1
|
||||
if max_length - cache_length < res_length < max_length:
|
||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
|
||||
@@ -22,18 +31,20 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_reward += format_score
|
||||
reward += format_score
|
||||
format_acc += 1
|
||||
|
||||
# Check answer accuracy
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if (
|
||||
final_answer is not None
|
||||
format_valid
|
||||
and final_answer is not None
|
||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||
):
|
||||
acc_reward += acc_score
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
|
||||
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
|
||||
reward = reward + length_reward
|
||||
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
|
||||
def gsm8k_reward_fn(input_ids, **kwargs):
|
||||
|
Reference in New Issue
Block a user