[feat][npu] Merge from grpo-latest (#6346)

* move prompt-level-filtering to buffer side

* move prompt-level-filtering to buffer side

* remove redundant code and fix bugs

* fix metric calculation

* fix missing tags parameter

* address conversation

* add overlength sample count (#6332)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* address conversation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo and parameter description

* [feat] Update requirements and set return logits False

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xysheng-colossal 2025-06-23 11:49:13 +08:00
parent 1304be44ae
commit be5acb02d9
8 changed files with 106 additions and 39 deletions

View File

@@ -128,16 +128,21 @@ class BaseConsumer:
             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
             # we need to calculate the metrics before filtering here for logging
             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-            raw_batch = {
+            raw_batch_with_reward = self.calculate_reward(
+                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+            )
+            raw_batch_with_reward = {
                 k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
-                for k, v in raw_batch.items()
+                for k, v in raw_batch_with_reward.items()
             }
             # [batch_size, num_generations] -> [batch_size]
-            reward = raw_batch["reward"][:, :, 0]
-            format_acc = raw_batch["format_acc"][:, :, 0]
-            ans_acc = raw_batch["ans_acc"][:, :, 0]
+            reward = raw_batch_with_reward["reward"][:, :, 0]
+            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
             response_len = (
-                raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
+                raw_batch_with_reward["response_idx"][:, :, 1]
+                - raw_batch_with_reward["response_idx"][:, :, 0]
+                + 1
             ).type(torch.float32)
             effective_group_mask = None
             if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -146,8 +151,8 @@ class BaseConsumer:
                 effective_group_mask = torch.logical_and(
                     group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                 )
-            raw_batch = unbind_batch(raw_batch)  # List[Dict[str, torch.Tensor]]
-            for group_idx, group_with_reward in enumerate(raw_batch):
+            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                 self.buffer.append(
                     [
                         (
@@ -163,7 +168,7 @@ class BaseConsumer:
                 )
             if effective_group_mask is not None:
                 print(
-                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                 )
             # mapping the effective group to the raw group for indexing
             effective_group_to_raw_group_mapping = {}
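For reference, a minimal standalone sketch of the group-level filter used above: groups whose mean answer accuracy falls outside filter_range are dropped before buffering. The tensor values and the (0.01, 0.99) bounds below are illustrative, not taken from the actual config.

    import torch

    # ans_acc: [batch_size, num_generations] answer accuracy per sampled response
    ans_acc = torch.tensor(
        [
            [0.0, 1.0, 1.0, 0.0],  # mean 0.50 -> kept
            [0.0, 0.0, 0.0, 0.0],  # mean 0.00 -> filtered out (all wrong)
            [1.0, 1.0, 1.0, 1.0],  # mean 1.00 -> filtered out (all correct)
        ]
    )
    filter_range = (0.01, 0.99)  # illustrative bounds

    group_ans_acc_mean = ans_acc.mean(dim=1)  # [batch_size]
    effective_group_mask = torch.logical_and(
        group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
    )
    print(effective_group_mask)  # tensor([ True, False, False])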

View File

@@ -1,12 +1,14 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -117,7 +119,20 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        grpo_config.get("response_format_tags", None)
+        response_format_tags = grpo_config.get("response_format_tags", None)
+        reward_model_kwargs = {
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
+        )
         self.global_step = 0
         self.lr_scheduler = CosineAnnealingWarmupLR(
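The reward_model_kwargs dict above is just grpo_config filtered down to a whitelist of keys before being splatted into VerifiableReward. A small sketch of that pattern, using made-up config values:

    # made-up grpo_config values, only for illustrating the key filtering
    grpo_config = {
        "reward_fn_type": "boxed",
        "beta": 0.01,
        "soft_over_length_punishment": True,
        "max_new_tokens": 1024,
        "cache_length": 256,
    }
    allowed_keys = ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
    reward_model_kwargs = {k: v for k, v in grpo_config.items() if k in allowed_keys}
    print(reward_model_kwargs)
    # {'soft_over_length_punishment': True, 'max_new_tokens': 1024, 'cache_length': 256}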
@@ -486,6 +501,40 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contain the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
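calculate_reward assumes the reward model returns one (reward, format_acc, ans_acc) triple per generation, which it then stacks back into [num_of_generation, 1] tensors. A sketch of that unpacking step in isolation, with a hand-written dummy output standing in for the VerifiableReward call:

    import torch

    # dummy reward-model output: one (reward, format_acc, ans_acc) triple per generation
    reward_model_output = [
        (1.0, 1.0, 1.0),
        (0.0, 1.0, 0.0),
        (0.5, 0.0, 1.0),
    ]

    reward = torch.tensor([v[0] for v in reward_model_output]).view((-1, 1))
    format_acc = torch.tensor([v[1] for v in reward_model_output]).view((-1, 1))
    ans_acc = torch.tensor([v[2] for v in reward_model_output]).view((-1, 1))
    print(reward.shape)  # torch.Size([3, 1])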

View File

@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)


 def launch_distributed(
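After this change the fast data-parallel size no longer divides by sp_size, so it is simply the process count over the product of the tensor-, pipeline- and expert-parallel degrees. A quick check of the arithmetic with illustrative values:

    def dp_size(n_procs, tp_size=1, pp_size=1, ep_size=1):
        # sequence parallelism no longer shrinks the data-parallel dimension here
        return n_procs // (tp_size * pp_size * ep_size)

    print(dp_size(8, tp_size=2, pp_size=2))  # 2
    print(dp_size(8, tp_size=2))             # 4, regardless of any sp_size setting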
@@ -130,7 +129,8 @@ def launch_distributed(
             consumer_plugin_config=plugin_config,
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
-            grpo_config=grpo_config,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            response_format_tags=grpo_config["response_format_tags"],
             eval_save_dir=eval_save_dir,
             eval_generation_config=eval_generation_config,
             project_name=project_name,

View File

@@ -43,7 +43,8 @@ class BaseProducer:
         consumer_plugin_config: Dict[str, Any] = None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -157,6 +158,13 @@ class BaseProducer:
                     ),
                     collate_fn=collate_fn_grpo,
                 )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+            self.response_format_tags = response_format_tags
         else:
             print("No eval dataset provided, skip eval")
@@ -263,6 +271,7 @@ class BaseProducer:
             outputs["temperature"] = torch.tensor(
                 [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
             ).to(outputs["input_ids"].device)
+            print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
             ray_broadcast_tensor_dict(
                 outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
             )
@@ -334,7 +343,8 @@ class SimpleProducer(BaseProducer):
         consumer_plugin_config=None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -358,7 +368,8 @@ class SimpleProducer(BaseProducer):
             consumer_plugin_config,
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
-            grpo_config=grpo_config,
+            evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,

View File

@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0

View File

@@ -240,7 +240,7 @@ if __name__ == "__main__":
        )
    generate_config.update(
        dict(
-            max_tokens=args.max_new_tokens,  # max new tokens
+            max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
            include_stop_str_in_output=True,
            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
@@ -327,13 +327,17 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "tp_size": 2,
-            "pp_size": 2,
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
             "microbatch_size": max(
-                1, args.train_microbatch_size // 2
+                1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
-            "zero_stage": 1,
+            "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
         },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
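The plugin_config now takes its parallel sizes from the CLI instead of hard-coded values, and the training microbatch is split across pipeline stages. A condensed sketch with stand-in argument values (the real script reads these from argparse):

    # stand-in values; in the example script these come from argparse
    tensor_parallel_size = 2
    pipeline_parallel_size = 2
    train_microbatch_size = 8
    zero_stage = 1

    plugin_config = {
        "tp_size": tensor_parallel_size,
        "pp_size": pipeline_parallel_size,
        # microbatch size should be set to train_microbatch_size // pp_size
        "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),
        "zero_stage": zero_stage,
        "max_norm": 1.0,
    }
    print(plugin_config["microbatch_size"])  # 4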

View File

@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
             else:
                 position_ids = position_ids.view(-1, seq_length).long()

-            if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+            if (
+                not shard_config.enable_flash_attention
+                and attention_mask is not None
+                and self._attn_implementation == "flash_attention_2"
+                and use_cache
+            ):
                 is_padding_right = attention_mask[:, -1].sum().item() != batch_size
                 if is_padding_right:
                     raise ValueError(
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non
         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

     return forward
@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = inputs_embeds
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
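Both forward paths now hand attention_mask = None to the attention op when shard_config.enable_flash_attention is set, instead of materialising the dense mask via ColoAttention.prepare_attn_kwargs; causal masking is presumably handled inside the fused kernel. A schematic sketch of that branch (build_dense_causal_mask is a hypothetical stand-in for the eager-path helper):

    from typing import Callable, Optional

    import torch


    def resolve_attention_mask(
        attention_mask: Optional[torch.Tensor],
        enable_flash_attention: bool,
        build_dense_causal_mask: Callable[[Optional[torch.Tensor]], torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # flash-attention path: no explicit 4-D mask tensor is built
        if enable_flash_attention:
            return None
        # eager path: fall back to an explicit [batch, 1, q_len, kv_len] causal mask
        return build_dense_causal_mask(attention_mask)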

View File

@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0