ColossalAI/applications/ColossalChat/rl_example.py

import argparse
import ray
import torch
from coati.distributed.launch import launch_distributed
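# Launcher script for ColossalChat's distributed RL training (GRPO and variants):
# it spins up Ray actors for rollout generation (inferencers/producers) and for
# policy updates (trainers/consumers) via launch_distributed.
# A minimal single-node invocation might look like the following (model, dataset
# path and parallel sizes are illustrative, not prescriptive):
#   python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl -t 2 -i 2 -b vllm -a GRPO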
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument(
"-ibs",
"--inference-batch-size",
type=int,
default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
"-imbs",
"--inference-microbatch-size",
type=int,
default=8,
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
"-tbs",
"--train-batch-size",
type=int,
default=32,
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
"-tMbs",
"--train-minibatch-size",
type=int,
default=1,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
args = parser.parse_args()
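# Sanity-check the batch-size hierarchy implied by the flags above: tbs (prompts
# per optimizer step per dp group) must exceed tMbs (prompts per mini-batch), and
# each mini-batch must expand into at least one micro-batch, i.e.
# tMbs * num_generations >= tmbs.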
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"
if args.master_address is None:
# Default setting: single-machine (local) Ray cluster
ray.init(address="local", namespace="ray-example")
else:
# For multi-node distributed training with Ray, set _node_ip_address to the IP address of your master node
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
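# Both the rollout (inference) and training copies of the policy are loaded from
# the same checkpoint; top_k=-1, top_p=1.0 and temperature=1.0 leave the sampling
# distribution of the rollouts unmodified.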
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
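# Backend-specific settings. The transformers backend takes HuggingFace
# generate() kwargs (max_length counts prompt + completion tokens), while the
# vLLM backend takes engine/sampling options (max_tokens counts completion
# tokens only).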
if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
)
)
generate_config.update(
dict(
max_length=1024 + 512,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
stop_strings=["</answer>"],
)
)
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=1024 * 4 + 510,
tensor_parallel_size=1,
)
)
generate_config.update(
dict(
max_tokens=1024 * 4,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
)
)
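# Fallback for any other backend; mem_fraction_static suggests an SGLang-style
# server config. Note that the --backend choices above currently only allow
# "transformers" and "vllm", so this branch is unreachable unless more choices
# are added.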
else:
inference_model_config.update(
dict(
mem_fraction_static=0.6,
)
)
generate_config.update(
dict(
max_new_tokens=256,
ignore_eos=True,
)
)
# Default Settings
# grpo_config = {
# "filter_range": [0.05, 9.0],
# "lr": 1e-6,
# "train_microbatch_size": train_microbatch_size,
# }
# DAPO variant settings
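# Compared with vanilla GRPO, this enables DAPO-style tweaks: asymmetric
# "clip-higher" PPO clipping (clip_eps_high > clip_eps_low), token-level loss
# aggregation, beta=0.0 to drop the KL penalty, and a soft over-length penalty
# applied over the last cache_length tokens before max_length, with truncated
# responses filtered out of the loss.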
grpo_config = {
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": 1e-6,
"train_microbatch_size": args.train_microbatch_size,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"skip_threshold": 20.0,
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": 1024 * 4,
"cache_length": 512,
"filter_truncated_response": True,
}
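# Hand everything to the Ray launcher: producers run the inference backend to
# generate rollouts, consumers run the GRPO/DAPO updates. plugin_config selects
# the ColossalAI plugin for the trainers (ZeRO stage 2 here; the commented-out
# config below is a tensor/pipeline-parallel alternative).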
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_consumer_procs=args.num_trainers,
num_episodes=10,
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
},  # for ZeRO
# plugin_config={
# "tp_size": 2,
# "microbatch_size": args.train_microbatch_size // 2,
# "zero_stage": 0,
# "max_norm": 1.0,
# },  # for pipeline parallelism
inference_backend=args.backend,
master_addr="localhost",
master_port=29506,
core_algo=args.algo,
project_name=args.project,
)