[feat] Support DAPO (#6263)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

* fix memory leak; support tp+pp

* move empty cache

* move empty cache

* add DAPO support

* remove format reward

* fix filtering, still buggy

* small fix

* add DAPO support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tested multi-node training; fix bind_batch bug

* fix conversation; support sleep mode

* support reusing excess samples

* add dynamic batching control flag

* add dynamic batching control flag

* refactored

* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-04-25 17:39:17 +08:00
Committed by: GitHub
Parent: b823c6eec7
Commit: 26d859f68e

10 changed files with 552 additions and 359 deletions

@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
 import ray
 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
 def get_jsonl_size_fast(path: str) -> int:
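
Note that this hunk gives DAPO no consumer class of its own: "DAPO" maps to the same GRPOConsumer, so the algorithm is selected by name and the DAPO-specific behaviour has to come from the new grpo_config. A minimal sketch of the dispatch, with the consumer classes stubbed out and get_consumer_cls as a hypothetical helper (the real code inlines the lookup):

# Sketch only: in the repo, SimpleConsumer and GRPOConsumer come from
# .consumer and .grpo_consumer; they are stubbed here for illustration.
class SimpleConsumer: ...
class GRPOConsumer: ...

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}

def get_consumer_cls(core_algo: str):
    # Mirrors the launch-time check below: unknown names are rejected early.
    if core_algo not in ALGO_MAP:
        raise NotImplementedError(f"{core_algo} is not supported yet")
    return ALGO_MAP[core_algo]

assert get_consumer_cls("DAPO") is GRPOConsumer  # DAPO reuses the GRPO consumer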
@@ -40,6 +40,7 @@ def launch_distributed(
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
     train_model_config: Dict[str, Any],
+    grpo_config: Dict[str, Any],
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
@@ -48,6 +49,8 @@ def launch_distributed(
     master_port: int = 29500,
     core_algo: str = "GRPO",
     project_name: Optional[str] = None,
+    save_interval: int = 100,
+    save_dir: str = "./model",
 ):
     if core_algo not in ALGO_MAP:
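
The two signature hunks above thread an algorithm config and checkpointing options through launch_distributed instead of hard-coding them. A hedged sketch of the new keyword arguments at a call site (placeholder values; the unchanged producer/inference arguments are omitted):

# Illustrative kwargs covering only the parameters this diff introduces;
# launch_distributed also takes the unchanged arguments omitted here.
new_launch_kwargs = dict(
    grpo_config={"lr": 1e-6},  # placeholder; a fuller sketch follows the last hunk
    core_algo="DAPO",          # resolves to GRPOConsumer via the updated ALGO_MAP
    save_interval=100,         # new: checkpoint every 100 steps (the diff's default)
    save_dir="./model",        # new: checkpoint output directory (the diff's default)
)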
@@ -101,15 +104,13 @@ def launch_distributed(
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_minibatch_size,
+            minibatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
-            training_config={
-                "filter_range": [0.05, 9.0],
-                "lr": 1e-6,
-                "train_microbatch_size": train_microbatch_size,
-            },
+            grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
+            save_interval=save_interval,
+            save_dir=save_dir,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
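
This last hunk deletes the hard-coded training_config, forwards the caller's grpo_config instead, and fixes the kwarg name (minibatch_size, not microbatch_size). A hedged sketch of what such a config might contain; the first three keys mirror the removed defaults, while the DAPO-specific keys are assumptions named only for illustration:

# The first three entries reproduce the defaults removed in this hunk; the
# remaining entries are hypothetical DAPO-style knobs, so the actual key
# names in the repo may differ.
grpo_config = {
    "filter_range": [0.05, 9.0],  # old default: filtering window
    "lr": 1e-6,                   # old default: learning rate
    "train_microbatch_size": 2,   # was forwarded from the launcher argument
    "clip_eps_high": 0.28,        # assumed: DAPO-style "clip higher" upper bound
    "dynamic_batching": True,     # assumed: the "dynamic batching control flag" commits
}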