mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-02 05:35:29 +00:00
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
import copy
|
|
from typing import Any, Dict, Optional
|
|
|
|
import ray
|
|
|
|
from .consumer import SimpleConsumer
|
|
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
|
|
from .producer import SimpleProducer
|
|
|
|
# Registry of supported algorithms: maps the `core_algo` name accepted by
# `launch_distributed` to the consumer class that is launched for it.
ALGO_MAP = {
    "Simple": SimpleConsumer,
    "GRPO": GRPOConsumer,
    "EvalGRPO": GRPOEvalConsumer,
}
|
|
|
|
|
|
def get_jsonl_size_fast(path: str) -> int:
    """Return the number of non-blank lines in the file at ``path``, minus one.

    Lines consisting only of whitespace are ignored. The file is streamed line
    by line, so memory use does not grow with the size of the dataset.

    Args:
        path: Path of the text (JSONL) file to count.

    Returns:
        ``(number of non-blank lines) - 1``. For a file without any non-blank
        line this is ``-1``.

    NOTE(review): the trailing ``- 1`` is kept exactly as callers already see
    it; confirm with the data-loader side whether one sample is deliberately
    left out of the count.
    """
    with open(path) as f:
        # Count lazily instead of materialising every line (and a filtered
        # copy of the list) just to take its length.
        non_blank_lines = sum(1 for line in f if line.strip())
    return non_blank_lines - 1
|
|
|
|
|
|
def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    """Compute the data-parallel size left over after the other parallel dimensions.

    Reads ``tp_size``, ``pp_size``, ``ep_size`` and ``sp_size`` from
    ``plugin_config`` (each defaults to 1 when absent) and floor-divides
    ``n_procs`` by their product.

    Args:
        n_procs: Total number of processes to split.
        plugin_config: Plugin configuration that may carry the size keys above.

    Returns:
        ``n_procs // (tp_size * pp_size * ep_size * sp_size)``.
    """
    non_dp_size = 1
    for size_key in ("tp_size", "pp_size", "ep_size", "sp_size"):
        non_dp_size = non_dp_size * plugin_config.get(size_key, 1)
    return n_procs // non_dp_size
|
|
|
|
|
|
def launch_distributed(
    num_producers: int,
    num_proc_per_producer: int,
    num_consumer_procs: int,
    num_episodes: int,
    inference_batch_size: int,
    inference_microbatch_size: int,
    train_batch_size: int,
    train_microbatch_size: int,
    train_minibatch_size: int,
    dataset_config: Dict[str, Any],
    dataloaders_config: Dict[str, Any],
    inference_model_config: Dict[str, Any],
    generate_config: Dict[str, Any],
    train_model_config: Dict[str, Any],
    plugin_config: Dict[str, Any],
    tokenizer_config: Optional[Dict[str, Any]] = None,
    inference_backend: str = "transformers",
    num_generations: int = 8,
    master_addr: str = "localhost",
    master_port: int = 29500,
    core_algo: str = "GRPO",
    project_name: Optional[str] = None,
) -> None:
    """Create the producer and consumer Ray actors and run them to completion.

    ``num_producers`` ``SimpleProducer`` actors and ``num_consumer_procs``
    consumer actors (class selected through ``core_algo``) are created. Then
    ``setup`` is called on every actor, followed by ``loop`` on every actor;
    each phase is awaited with ``ray.get``.

    Args:
        num_producers: Number of producer actors to create.
        num_proc_per_producer: Number of GPUs requested for each producer actor.
        num_consumer_procs: Number of consumer actors; each requests one GPU
            and receives this value as its ``world_size``.
        num_episodes: Forwarded to producers and consumers.
        inference_batch_size: Per-producer batch size (producer ``batch_size``).
        inference_microbatch_size: Producer ``microbatch_size``; also used to
            derive ``num_recv_per_update``.
        train_batch_size: Consumer ``batch_size``.
        train_microbatch_size: Stored in the consumer ``training_config`` under
            ``"train_microbatch_size"``.
        train_minibatch_size: Forwarded to consumers as ``microbatch_size``.
        dataset_config: Producer dataset configuration; ``dataset_config["path"]``
            must point to the file whose lines are counted to size an episode.
        dataloaders_config: Forwarded to producers.
        inference_model_config: Producer ``model_config``.
        generate_config: Generation settings for producers. Consumers receive a
            deep copy extended with ``backend=inference_backend``.
        train_model_config: Consumer ``model_config``.
        plugin_config: Consumer plugin configuration; its ``tp_size``,
            ``pp_size``, ``ep_size`` and ``sp_size`` entries determine the
            training data-parallel size.
        tokenizer_config: Forwarded to producers.
        inference_backend: Producer ``backend``.
        num_generations: Forwarded to producers and consumers.
        master_addr: Forwarded to consumers.
        master_port: Forwarded to consumers.
        core_algo: Key into ``ALGO_MAP`` selecting the consumer class.
        project_name: Forwarded to consumers.

    Raises:
        NotImplementedError: If ``core_algo`` is not a key of ``ALGO_MAP``.
        ValueError: If the consumer processes cannot form at least one
            data-parallel group for the given ``plugin_config``.
        AssertionError: If the global inference batch is not a multiple of the
            global training batch.
    """
    if core_algo not in ALGO_MAP:
        raise NotImplementedError(f"{core_algo} is not supported yet.")
    core_consumer = ALGO_MAP[core_algo]

    # `plugin_config` configures the consumers, so the data-parallel size has
    # to be derived from the number of consumer processes, not from the number
    # of producers (which could even yield 0 and a ZeroDivisionError below).
    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
    if train_dp_size < 1:
        raise ValueError(
            f"num_consumer_procs={num_consumer_procs} is too small for the parallel sizes in plugin_config"
        )
    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0, (
        "inference_batch_size * num_producers must be divisible by train_batch_size * train_dp_size"
    )

    dataset_path = dataset_config["path"]
    num_samples = get_jsonl_size_fast(dataset_path)
    global_inference_batch_size = inference_batch_size * num_producers
    num_update_per_episode = num_samples // global_inference_batch_size
    num_recv_per_update = inference_batch_size // inference_microbatch_size

    procs = []
    for i in range(num_producers):
        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
            producer_idx=i,
            num_producers=num_producers,
            num_consumer_procs=num_consumer_procs,
            num_episodes=num_episodes,
            batch_size=inference_batch_size,
            dataset_config=dataset_config,
            dataloaders_config=dataloaders_config,
            model_config=inference_model_config,
            generate_config=generate_config,
            tokenizer_config=tokenizer_config,
            microbatch_size=inference_microbatch_size,
            backend=inference_backend,
            num_generations=num_generations,
        )
        procs.append(producer)

    # Consumers get their own copy so adding `backend` never leaks into the
    # caller's `generate_config` (which was handed to the producers unchanged).
    generate_config_consumer = copy.deepcopy(generate_config)
    generate_config_consumer.update(
        dict(
            backend=inference_backend,
        )
    )
    for i in range(num_consumer_procs):
        consumer = core_consumer.options(num_gpus=1).remote(
            num_producers=num_producers,
            num_episodes=num_episodes,
            rank=i,
            world_size=num_consumer_procs,
            master_addr=master_addr,
            master_port=master_port,
            num_update_per_episode=num_update_per_episode,
            num_recv_per_update=num_recv_per_update,
            batch_size=train_batch_size,
            model_config=train_model_config,
            plugin_config=plugin_config,
            microbatch_size=train_minibatch_size,
            generate_config=generate_config_consumer,
            training_config={
                "filter_range": [0.05, 9.0],
                "lr": 1e-6,
                "train_microbatch_size": train_microbatch_size,
            },
            num_generations=num_generations,
            project_name=project_name,
        )
        procs.append(consumer)

    # Finish setup on every actor before any main loop is started.
    ray.get([p.setup.remote() for p in procs])
    ray.get([p.loop.remote() for p in procs])
|