[feat][npu] Merge from grpo-latest (#6346)

* move prompt-level-filtering to buffer side

* move prompt-level-filtering to buffer side

* remove redundant code and fix bugs

* fix metric calculation

* fix missing tags parameter

* address conversation

* add overlength sample count (#6332)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* address conversation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo and parameter description

* [feat] Update requirements and set return logits to False

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xysheng-colossal 2025-06-23 11:49:13 +08:00
parent 1304be44ae
commit be5acb02d9
8 changed files with 106 additions and 39 deletions

View File

@@ -128,16 +128,21 @@ class BaseConsumer:
                 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                 # we need to calculate the metrics before filtering here for logging
                 # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-                raw_batch = {
+                raw_batch_with_reward = self.calculate_reward(
+                    {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                )
+                raw_batch_with_reward = {
                     k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
-                    for k, v in raw_batch.items()
+                    for k, v in raw_batch_with_reward.items()
                 }
                 # [batch_size, num_generations] -> [batch_size]
-                reward = raw_batch["reward"][:, :, 0]
-                format_acc = raw_batch["format_acc"][:, :, 0]
-                ans_acc = raw_batch["ans_acc"][:, :, 0]
+                reward = raw_batch_with_reward["reward"][:, :, 0]
+                format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
                 response_len = (
-                    raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
+                    raw_batch_with_reward["response_idx"][:, :, 1]
+                    - raw_batch_with_reward["response_idx"][:, :, 0]
+                    + 1
                 ).type(torch.float32)
                 effective_group_mask = None
                 if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -146,8 +151,8 @@ class BaseConsumer:
                     effective_group_mask = torch.logical_and(
                         group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                     )
-                raw_batch = unbind_batch(raw_batch)  # List[Dict[str, torch.Tensor]]
-                for group_idx, group_with_reward in enumerate(raw_batch):
+                raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                     self.buffer.append(
                         [
                             (
@@ -163,7 +168,7 @@ class BaseConsumer:
                     )
                 if effective_group_mask is not None:
                     print(
-                        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                    )
                # mapping the effective group to the raw group for indexing
                effective_group_to_raw_group_mapping = {}
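For reference, a minimal standalone sketch of the prompt-level filtering this hunk moves to the consumer/buffer side: a prompt group is kept only if its mean answer accuracy lies strictly inside filter_range, and the surviving groups are then re-indexed against the raw batch. The function and variable names below are illustrative, not the repository's API.

import torch

def build_effective_group_mask(ans_acc: torch.Tensor, filter_range=(0.01, 0.99)) -> torch.Tensor:
    # ans_acc: [batch_size, num_generations] per-sample answer accuracy (illustrative helper)
    group_ans_acc_mean = ans_acc.mean(dim=1)  # [batch_size]
    return torch.logical_and(group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1])

# example: 3 prompt groups, 4 generations each
ans_acc = torch.tensor([[0.0, 0.0, 0.0, 0.0],   # all wrong   -> filtered out
                        [1.0, 0.0, 1.0, 0.0],   # mixed       -> kept
                        [1.0, 1.0, 1.0, 1.0]])  # all correct -> filtered out
mask = build_effective_group_mask(ans_acc)
# map effective group index -> raw group index, as the consumer does for indexing
effective_to_raw = {eff: raw for eff, raw in enumerate(torch.nonzero(mask).squeeze(-1).tolist())}
print(mask.tolist(), effective_to_raw)  # [False, True, False] {0: 1}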

View File

@@ -1,12 +1,14 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -117,7 +119,20 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        grpo_config.get("response_format_tags", None)
+        response_format_tags = grpo_config.get("response_format_tags", None)
+        reward_model_kwargs = {
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
+        )

         self.global_step = 0
         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -486,6 +501,40 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contain the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
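As a rough illustration of what the new calculate_reward does with the reward model output, here is a self-contained sketch. fake_reward_fn and attach_reward are stand-ins invented for this example (not coati's VerifiableReward or the real method); the assumption, taken from the hunk above, is that the reward call yields one (reward, format_acc, ans_acc) triple per generation, which is then attached to the rollout as [num_of_generation, 1] columns.

import torch
from typing import Any, Dict, List, Tuple

def fake_reward_fn(num_samples: int) -> List[Tuple[float, float, float]]:
    # stand-in for VerifiableReward(...)(input_ids, gt_answer=..., response_idx=...)
    return [(1.0, 1.0, 1.0) if i % 2 == 0 else (0.0, 1.0, 0.0) for i in range(num_samples)]

def attach_reward(rollout: Dict[str, Any]) -> Dict[str, Any]:
    # mirrors how the triples become per-sample reward/format_acc/ans_acc columns
    output = fake_reward_fn(rollout["input_ids"].size(0))
    device = rollout["input_ids"].device
    rollout["reward"] = torch.tensor([v[0] for v in output], device=device).view(-1, 1)
    rollout["format_acc"] = torch.tensor([v[1] for v in output], device=device).view(-1, 1)
    rollout["ans_acc"] = torch.tensor([v[2] for v in output], device=device).view(-1, 1)
    return rollout

rollout = {"input_ids": torch.zeros(4, 8, dtype=torch.long)}
print(attach_reward(rollout)["reward"].shape)  # torch.Size([4, 1])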

View File

@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)


 def launch_distributed(
@@ -130,7 +129,8 @@ def launch_distributed(
        consumer_plugin_config=plugin_config,
        eval_dataset_config=eval_dataset_config,
        eval_interval=eval_interval,
-        grpo_config=grpo_config,
+        evaluation_function_type=grpo_config["reward_fn_type"],
+        response_format_tags=grpo_config["response_format_tags"],
        eval_save_dir=eval_save_dir,
        eval_generation_config=eval_generation_config,
        project_name=project_name,
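The sp_size divisor is dropped from the fast data-parallel size estimate, presumably because the example script now reuses the tensor-parallel group for sequence parallelism (sp_size = tp_size with split_gather), so sequence parallelism no longer consumes extra ranks. A quick worked check of the revised formula, copied from the hunk above:

from typing import Any, Dict

def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    ep_size = plugin_config.get("ep_size", 1)
    return n_procs // (tp_size * pp_size * ep_size)

# e.g. 8 trainer processes with tp_size=2, pp_size=2, ep_size=1:
print(get_dp_size_fast(8, {"tp_size": 2, "pp_size": 2}))  # 2 (the old formula with sp_size=2 gave 1)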

View File

@@ -43,7 +43,8 @@ class BaseProducer:
        consumer_plugin_config: Dict[str, Any] = None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
        eval_save_dir: str = "./eval",
        project_name: str = None,
        run_name: str = None,
@@ -157,6 +158,13 @@ class BaseProducer:
                ),
                collate_fn=collate_fn_grpo,
            )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+            self.response_format_tags = response_format_tags
        else:
            print("No eval dataset provided, skip eval")

@@ -263,6 +271,7 @@ class BaseProducer:
                outputs["temperature"] = torch.tensor(
                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                ).to(outputs["input_ids"].device)
+                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )
@@ -334,7 +343,8 @@ class SimpleProducer(BaseProducer):
        consumer_plugin_config=None,
        eval_dataset_config=None,
        eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
        eval_save_dir: str = "./eval",
        eval_generation_config={},
        project_name: str = None,
@@ -358,7 +368,8 @@ class SimpleProducer(BaseProducer):
            consumer_plugin_config,
            eval_dataset_config=eval_dataset_config,
            eval_interval=eval_interval,
-            grpo_config=grpo_config,
+            evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
            eval_save_dir=eval_save_dir,
            project_name=project_name,
            run_name=run_name,
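The producer now receives evaluation_function_type and response_format_tags directly instead of the whole grpo_config, and the if/elif above is simply a two-entry dispatch. A hedged sketch of that dispatch with stand-in reward functions (the real ones are coati's math_reward_fn and boxed_math_reward_fn; the lookup table and helper name here are made up for illustration):

# stand-ins for coati.distributed.reward.reward_fn.{math_reward_fn, boxed_math_reward_fn}
def math_reward_fn(*args, **kwargs): ...
def boxed_math_reward_fn(*args, **kwargs): ...

EVAL_FN_BY_TYPE = {
    "think_answer_tags": math_reward_fn,
    "boxed": boxed_math_reward_fn,
}

def select_evaluation_function(evaluation_function_type: str):
    try:
        return EVAL_FN_BY_TYPE[evaluation_function_type]
    except KeyError:
        raise ValueError(f"Unknown evaluation function type {evaluation_function_type}") from None

print(select_evaluation_function("boxed").__name__)  # boxed_math_reward_fn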

View File

@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0

View File

@@ -240,7 +240,7 @@ if __name__ == "__main__":
        )
    generate_config.update(
        dict(
-            max_tokens=args.max_new_tokens,  # max new tokens
+            max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
            include_stop_str_in_output=True,
            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
@@ -327,13 +327,17 @@ if __name__ == "__main__":
        train_model_config=train_model_config,
        grpo_config=grpo_config,
        plugin_config={
-            "tp_size": 2,
-            "pp_size": 2,
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
            "microbatch_size": max(
-                1, args.train_microbatch_size // 2
+                1, args.train_microbatch_size // args.pipeline_parallel_size
            ),  # microbatch size should be set to train_microbatch_size // pp_size
-            "zero_stage": 1,
+            "zero_stage": args.zero_stage,
            "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
        },  # for pp, tp
        inference_backend=args.backend,
        master_addr="localhost",
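To make the new CLI-driven plugin sizing concrete, a small sketch; the helper name is made up, but the keys and formulas follow the hunk above. With args.tensor_parallel_size=2, args.pipeline_parallel_size=2 and args.train_microbatch_size=8, the consumer plugin gets tp_size=2, pp_size=2, sp_size=2 and microbatch_size = max(1, 8 // 2) = 4.

def build_plugin_config(tensor_parallel_size, pipeline_parallel_size, train_microbatch_size, zero_stage):
    # illustrative helper; the example script builds this dict inline from argparse values
    return {
        "tp_size": tensor_parallel_size,
        "pp_size": pipeline_parallel_size,
        "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),
        "zero_stage": zero_stage,
        "max_norm": 1.0,
        "enable_flash_attention": True,
        "sp_size": tensor_parallel_size,  # sequence parallelism reuses the tp group
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "split_gather",
    }

print(build_plugin_config(2, 2, 8, 1)["microbatch_size"])  # 4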

View File

@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
        else:
            position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
        # for the other stages, hidden_states is the output of the previous stage
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
            attention_mask = None
        else:
            if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non
        attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

    return forward
@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
        hidden_states = inputs_embeds

        if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,

View File

@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages are built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0