Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-20 17:10:03 +00:00)
[feat] Support DAPO (#6263)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation
* fix memory leakage support tp+pp
* move empty cache
* move empty cache
* add DAPO support
* remove format reward
* fix filtering, still buggy
* small fix
* add DAPO support
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* tested multi-node training; fix bind_batch bug
* fix conversation; support sleep mode
* support reusing excessive samples
* add dynamic batching control flag
* add dynamic batching control flag
* refactored
* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
```diff
@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
 import ray
 
 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
+from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
```
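The notable change here is that `DAPO` is registered in `ALGO_MAP` but maps to the same `GRPOConsumer` class as `GRPO`: DAPO is implemented as a configuration of the GRPO consumer rather than a new consumer class. Below is a minimal sketch of the dispatch this map enables; `ALGO_MAP` and the `core_algo not in ALGO_MAP` guard come from this diff, while the placeholder classes, the helper function, and its error message are illustrative assumptions:

```python
# Placeholder classes stand in for the real consumer imports.
class SimpleConsumer: ...
class GRPOConsumer: ...

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}

def resolve_consumer(core_algo: str) -> type:
    # Same membership guard as in the diff below; the error message is assumed.
    if core_algo not in ALGO_MAP:
        raise NotImplementedError(f"Unsupported algorithm: {core_algo}")
    return ALGO_MAP[core_algo]

# "GRPO" and "DAPO" resolve to the same class; behavior differs only through
# the grpo_config handed to the consumer at construction time.
assert resolve_consumer("DAPO") is resolve_consumer("GRPO")
```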
```diff
@@ -40,6 +40,7 @@ def launch_distributed(
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
     train_model_config: Dict[str, Any],
+    grpo_config: Dict[str, Any],
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
@@ -48,6 +49,8 @@ def launch_distributed(
     master_port: int = 29500,
     core_algo: str = "GRPO",
     project_name: Optional[str] = None,
+    save_interval: int = 100,
+    save_dir: str = "./model",
 ):
 
     if core_algo not in ALGO_MAP:
```
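These two hunks thread a caller-supplied `grpo_config` through `launch_distributed` and add checkpointing controls (`save_interval`, `save_dir`). The `grpo_config` dict replaces the hard-coded `training_config` that the final hunk below removes. A hedged sketch of what a caller might pass follows; only the three keys are grounded in this diff (they reproduce the previously hard-coded values), and any DAPO-specific knobs `GRPOConsumer` understands are not visible here:

```python
# Sketch of a grpo_config built by the caller. The keys mirror the
# training_config that was previously hard-coded at the consumer call site
# (removed in the final hunk); the last value is a placeholder.
grpo_config = {
    "filter_range": [0.05, 9.0],   # filtering bounds, as previously hard-coded
    "lr": 1e-6,                    # learning rate, as previously hard-coded
    "train_microbatch_size": 2,    # placeholder; previously a passed-in variable
}
```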
```diff
@@ -101,15 +104,13 @@ def launch_distributed(
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_minibatch_size,
+            minibatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
-            training_config={
-                "filter_range": [0.05, 9.0],
-                "lr": 1e-6,
-                "train_microbatch_size": train_microbatch_size,
-            },
+            grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
+            save_interval=save_interval,
+            save_dir=save_dir,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
```
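Putting the pieces together, a hedged call sketch: every keyword below appears in the diff, but parameters that fall outside the displayed hunks are omitted, so this is a fragment rather than a complete runnable call, and all values are placeholders.

```python
# Fragment only: arguments not visible in the hunks above are omitted.
launch_distributed(
    inference_model_config={"path": "your/model"},  # placeholder
    generate_config={"temperature": 1.0},           # placeholder
    train_model_config={"path": "your/model"},      # placeholder
    grpo_config=grpo_config,            # dict sketched above
    plugin_config={},                   # placeholder plugin settings
    inference_backend="transformers",   # default shown in the diff
    core_algo="DAPO",                   # resolves to GRPOConsumer via ALGO_MAP
    project_name="dapo-run",            # hypothetical project name
    save_interval=100,                  # new default: checkpoint every 100 steps
    save_dir="./model",                 # new default checkpoint directory
)
```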