mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
135 lines
5.1 KiB
Python
import argparse
|
|
|
|
import ray
|
|
import torch
|
|
from coati.distributed.launch import launch_distributed
|
|
|
|
def _parse_args():
    """Parse and validate the command-line options of the distributed RL example.

    Invalid batch-size / parallelism combinations are reported via
    ``parser.error`` (prints usage and exits with status 2) instead of
    ``assert``, so the checks are not stripped when Python runs with ``-O``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
    parser.add_argument("-t", "--num-trainers", type=int, default=2)
    parser.add_argument("-i", "--num-inferencer", type=int, default=2)
    parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
    parser.add_argument(
        "-ibs",
        "--inference-batch-size",
        type=int,
        default=64,
        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
    )
    parser.add_argument(
        "-imbs",
        "--inference-microbatch-size",
        type=int,
        default=8,
        help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
    )
    parser.add_argument(
        "-tbs",
        "--train-batch-size",
        type=int,
        default=32,
        help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
    )
    parser.add_argument(
        "-tMbs",
        "--train-minibatch-size",
        type=int,
        default=1,
        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
    )
    parser.add_argument(
        "-tmbs",
        "--train-microbatch-size",
        type=int,
        default=2,
        help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
    )
    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
    # Training-plugin parallelism; defaults reproduce the previously hard-coded values.
    parser.add_argument("--tp-size", type=int, default=2, help="Tensor-parallel size of the training plugin.")
    parser.add_argument("--pp-size", type=int, default=2, help="Pipeline-parallel size of the training plugin.")
    args = parser.parse_args()

    if args.train_minibatch_size <= 0:
        parser.error("Train mini batch size must be greater than 0")
    if not 0 < args.train_microbatch_size <= args.train_minibatch_size * args.num_generations:
        parser.error(
            "Train micro batch size must be greater than 0 and no larger than "
            "train mini batch size * num generations"
        )
    if args.tp_size <= 0 or args.pp_size <= 0:
        parser.error("--tp-size and --pp-size must be greater than 0")
    # The plugin config below uses train_microbatch_size // 2 as the pipeline
    # micro-batch size; with pipeline parallelism enabled that must not be 0.
    if args.pp_size > 1 and args.train_microbatch_size // 2 < 1:
        parser.error("Train micro batch size must be at least 2 when pipeline parallelism is enabled (--pp-size > 1)")
    return args


def _build_backend_configs(backend, model):
    """Build the inference-model and generation configs for an inference backend.

    Args:
        backend: ``"transformers"`` or ``"vllm"`` (the argparse choices).
        model: model path / hub id forwarded as ``path`` in the model config.

    Returns:
        tuple: ``(inference_model_config, generate_config)`` dicts.

    Raises:
        ValueError: if ``backend`` is not one of the supported backends.
    """
    inference_model_config = dict(path=model)
    # Sampling settings shared by every backend.
    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

    if backend == "transformers":
        inference_model_config.update(
            dict(
                use_flash_attention_2=True,
                torch_dtype=torch.bfloat16,
            )
        )
        generate_config.update(
            dict(
                max_length=1024 + 512,
                do_sample=True,
                max_new_tokens=None,
                early_stopping=False,
                stop_strings=["</answer>"],
            )
        )
    elif backend == "vllm":
        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
        generate_config.update(
            dict(
                max_tokens=2048,
                ignore_eos=True,
                include_stop_str_in_output=True,
                stop=["</answer>"],
            )
        )
    else:
        # argparse `choices` already restricts the backend; fail loudly if a new
        # choice is added without a matching config here.
        raise ValueError(f"Unsupported inference backend: {backend!r}")
    return inference_model_config, generate_config


if __name__ == "__main__":
    args = _parse_args()

    ray.init(address="local", namespace="ray-example")

    inference_model_config, generate_config = _build_backend_configs(args.backend, args.model)
    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)

    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=1,
        num_consumer_procs=args.num_trainers,
        num_episodes=10,
        inference_batch_size=args.inference_batch_size,
        inference_microbatch_size=args.inference_microbatch_size,
        train_batch_size=args.train_batch_size,
        train_minibatch_size=args.train_minibatch_size,
        train_microbatch_size=args.train_microbatch_size,
        dataset_config={"path": args.dataset, "max_length": 300},
        dataloaders_config={},
        inference_model_config=inference_model_config,
        generate_config=generate_config,
        num_generations=args.num_generations,
        train_model_config=train_model_config,
        # Tensor/pipeline-parallel training config; for a ZeRO-only run pass
        # plugin_config={} instead.
        plugin_config={
            "pp_size": args.pp_size,
            "tp_size": args.tp_size,
            "microbatch_size": args.train_microbatch_size // 2,
            "zero_stage": 0,
            "max_norm": 1.0,
        },
        inference_backend=args.backend,
        master_addr="localhost",
        master_port=29506,
        core_algo=args.algo,
        project_name=args.project,
    )
|