mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-17 04:33:42 +00:00
[feat] GRPO with distributed implementation (#6230)
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
import torch
|
||||
from coati.distributed.utils import log_probs_from_logits
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer import dist_log_prob_1d
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
# Test parallel configuration: 2-way 1D tensor parallelism, no data or
# pipeline parallelism.
CONFIG = {
    "parallel": {"data": 1, "pipeline": 1, "tensor": {"size": 2, "mode": "1d"}},
}
|
||||
|
||||
|
||||
def check_dist_log_prob(rank, world_size, port):
    """Check that ``dist_log_prob_1d`` matches the unsharded reference.

    Runs on one process of a spawned tensor-parallel group: computes
    reference log-probs and gradients on the full logits with
    ``log_probs_from_logits``, then recomputes them on this rank's
    vocab shard with ``dist_log_prob_1d`` and asserts that both the
    forward values and the backward gradients agree.

    Args:
        rank: Rank of this process within the spawned group.
        world_size: Number of processes (tensor-parallel degree).
        port: Free port used for the distributed rendezvous.

    Raises:
        AssertionError: If the distributed log-probs or gradients
            diverge from the unsharded reference.
    """
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")

    # Prepare reference data: full (unsharded) logits of shape
    # (batch=2, seq=4, vocab=8) and integer labels in [0, 8).
    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
    labels = torch.randint(8, (2, 4)).cuda()

    logprob = log_probs_from_logits(pred, labels)

    # .cuda() made `pred` a non-leaf tensor, so retain_grad() is needed
    # for `pred.grad` to be populated by backward().
    pred.retain_grad()
    logprob.mean().backward()

    # Shard the vocab (last) dimension: this rank keeps only its slice,
    # detached and re-marked as requiring grad so it gets its own .grad.
    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
    dist_pred.requires_grad = True
    dist_logprob = dist_log_prob_1d(dist_pred, labels)

    dist_pred.retain_grad()
    dist_logprob.squeeze(-1).mean().backward()

    # Forward outputs must match the unsharded reference.
    assert torch.allclose(
        logprob, dist_logprob.squeeze(-1), atol=1e-5
    ), f"dist cross entropy logprob is not equal to origin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"

    # The gradient on this rank's shard must match the corresponding
    # slice of the reference gradient.
    pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
    assert torch.allclose(
        pred_grad_partial, dist_pred.grad
    ), f"dist grad is not equal to origin grad\n{pred.grad}\n{dist_pred.grad}"
|
||||
|
||||
|
||||
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_log_prob():
    """Spawn a 2-process group and run the distributed log-prob check on every rank."""
    spawn(check_dist_log_prob, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test directly as a script, without pytest.
    test_dist_log_prob()
|
||||
Reference in New Issue
Block a user