mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[feat] GRPO with distributed implementation (#6230)
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -33,6 +34,8 @@ class BaseConsumer:
|
||||
model_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
microbatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
):
|
||||
self.num_producers = num_producers
|
||||
self.num_episodes = num_episodes
|
||||
@@ -44,14 +47,16 @@ class BaseConsumer:
|
||||
self.num_recv_per_update = num_recv_per_update
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
self.save_interval = save_interval
|
||||
self.save_dir = save_dir
|
||||
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
|
||||
|
||||
self.device = get_current_device()
|
||||
self.lr_scheduler = None
|
||||
|
||||
def setup(self) -> None:
|
||||
for i in range(self.num_producers):
|
||||
@@ -60,18 +65,15 @@ class BaseConsumer:
|
||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
||||
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
|
||||
|
||||
plugin_config = dict(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
precision="bf16",
|
||||
zero_stage=1,
|
||||
)
|
||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
||||
plugin_config["microbatch_size"] = self.microbatch_size
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
||||
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
||||
|
||||
self.dp_size = dist.get_world_size(self.plugin.dp_group)
|
||||
|
||||
self.buffer = []
|
||||
@@ -94,7 +96,6 @@ class BaseConsumer:
|
||||
i = 0
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||
self.buffer.extend(
|
||||
@@ -116,13 +117,26 @@ class BaseConsumer:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
assert len(self.buffer) == 0
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
if (step + 1) % self.save_interval == 0:
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {step + 1}.")
|
||||
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
Reference in New Issue
Block a user