[feat] Support DAPO (#6263)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

* fix memory leakage; support tp+pp

* move empty cache

* move empty cache

* add DAPO support

* remove format reward

* fix filtering, still buggy

* small fix

* add DAPO support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tested multi-node training; fix bind_batch bug

* fix conversation; support sleep mode

* support reusing excessive samples

* add dynamic batching control flag

* add dynamic batching control flag

* refactored

* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YeAnbang committed via GitHub on 2025-04-25 17:39:17 +08:00
parent b823c6eec7
commit 26d859f68e
10 changed files with 552 additions and 359 deletions


@@ -4,13 +4,22 @@ from .reward_utils import extract_solution, validate_response_structure
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     format_score = 1.0
-    acc_score = 9.0
     tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
+    acc_score = 10.0
     reward = torch.tensor(0.0)
-    format_reward = torch.tensor(0.0)
-    acc_reward = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    ans_acc = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if max_length - cache_length < res_length < max_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
     if gt_answer is None:
         return reward
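
The block added above is DAPO's soft over-length punishment: a response is penalty-free up to max_length - cache_length tokens, and the reward is then reduced linearly, reaching -acc_score as the length approaches max_length. A minimal sketch of the penalty in isolation, using the defaults visible in the diff (max_length = 1024 * 4, cache_length = 512, acc_score = 10.0); the standalone helper below is illustrative only, not a function in the repository:

# Hypothetical helper mirroring the penalty term added in the hunk above.
def soft_length_penalty(res_length, max_length=1024 * 4, cache_length=512, acc_score=10.0):
    # No penalty until the response enters the last `cache_length` tokens
    # before `max_length`; then a linear penalty down to -acc_score.
    if max_length - cache_length < res_length < max_length:
        return ((max_length - cache_length) - res_length) / cache_length * acc_score
    return 0.0

for n in (3584, 3712, 3840, 3968, 4095):
    print(n, round(soft_length_penalty(n), 2))  # 0.0, -2.5, -5.0, -7.5, -9.98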
@@ -22,18 +31,20 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     # Check format accuracy
     if format_valid:
-        format_reward += format_score
-        reward += format_score
+        format_acc += 1
-    # Check answer accuracy
+    # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
     if (
-        final_answer is not None
+        format_valid
+        and final_answer is not None
         and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
     ):
-        acc_reward += acc_score
+        ans_acc += 1
         reward += acc_score
-    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
+    reward = reward + length_reward
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)

 def gsm8k_reward_fn(input_ids, **kwargs):
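
Taken together, the two hunks change the reward semantics: format validity is now only tracked as a statistic (format_acc) rather than adding format_score to the reward, the answer counts as correct only when the format is also valid, and the soft length penalty is folded into the scalar reward before returning [reward, format_acc, ans_acc]. Below is a self-contained sketch of that bookkeeping, with the decoding and parsing steps (tokenizer, extract_solution, validate_response_structure) replaced by plain inputs; the helper name and signature are illustrative, not the repository's API:

import torch

# Simplified, hypothetical stand-in for the updated math_reward_fn bookkeeping.
def dapo_math_reward(format_valid, final_answer, gt_answer, length_reward=0.0, acc_score=10.0):
    reward = torch.tensor(0.0)
    format_acc = torch.tensor(0.0)
    ans_acc = torch.tensor(0.0)
    # Format validity is recorded but no longer adds to the scalar reward.
    if format_valid:
        format_acc += 1
    # The answer counts only if the format is valid and the normalized strings match.
    if (
        format_valid
        and final_answer is not None
        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
    ):
        ans_acc += 1
        reward += acc_score
    # Soft over-length punishment (zero or negative) is added to the scalar reward.
    reward = reward + length_reward
    return torch.tensor([reward, format_acc, ans_acc])

print(dapo_math_reward(True, "42", "42"))                      # tensor([10., 1., 1.])
print(dapo_math_reward(False, "42", "42"))                     # tensor([0., 0., 0.])
print(dapo_math_reward(True, "42", "42", length_reward=-2.5))  # tensor([7.5000, 1.0000, 1.0000])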